def test_apart_list():
    """Check the structured output of ``apart_list`` on three rational functions."""
    from sympy.utilities.iterables import numbered_symbols

    w0, w1, w2 = Symbol("w0"), Symbol("w1"), Symbol("w2")
    _a = Dummy("a")

    # Every expected entry uses the same denominator map ``a -> x - a``.
    def denominator():
        return Lambda(_a, -_a + x)

    # Case 1: improper fraction with a single linear factor.
    f = (-2 * x - 2 * x ** 2) / (3 * x ** 2 - 6 * x)
    result = apart_list(f, x, dummies=numbered_symbols("w"))
    expected = (
        -1,
        Poly(S(2) / 3, x, domain="QQ"),
        [(Poly(w0 - 2, w0, domain="ZZ"), Lambda(_a, 2), denominator(), 1)],
    )
    assert result == expected

    # Case 2: irreducible quadratic denominator.
    result = apart_list(2 / (x ** 2 - 2), x, dummies=numbered_symbols("w"))
    expected = (
        1,
        Poly(0, x, domain="ZZ"),
        [(Poly(w0 ** 2 - 2, w0, domain="ZZ"), Lambda(_a, _a / 2), denominator(), 1)],
    )
    assert result == expected

    # Case 3: quintic denominator giving three partial-fraction entries.
    f = 36 / (x ** 5 - 2 * x ** 4 - 2 * x ** 3 + 4 * x ** 2 + x - 2)
    result = apart_list(f, x, dummies=numbered_symbols("w"))
    entries = [
        (Poly(w0 - 2, w0, domain="ZZ"), Lambda(_a, 4), denominator(), 1),
        (Poly(w1 ** 2 - 1, w1, domain="ZZ"), Lambda(_a, -3 * _a - 6), denominator(), 2),
        (Poly(w2 + 1, w2, domain="ZZ"), Lambda(_a, -4), denominator(), 1),
    ]
    assert result == (1, Poly(0, x, domain="ZZ"), entries)
def test_apart_list():
    """Check ``apart_list`` output, comparing results up to renaming of dummies."""
    from sympy.utilities.iterables import numbered_symbols

    def matches(lhs, rhs):
        # Lists/tuples are compared element-wise; leaves are equal either
        # directly or via ``dummy_eq``.
        if type(lhs) in (list, tuple):
            for left_item, right_item in zip(lhs, rhs):
                if not matches(left_item, right_item):
                    return False
            return True
        return lhs == rhs or lhs.dummy_eq(rhs)

    w0, w1, w2 = Symbol("w0"), Symbol("w1"), Symbol("w2")
    _a = Dummy("a")

    # Case 1: improper fraction with a single linear factor.
    f = (-2 * x - 2 * x**2) / (3 * x**2 - 6 * x)
    result = apart_list(f, x, dummies=numbered_symbols("w"))
    expected = (
        -1,
        Poly(Rational(2, 3), x, domain='QQ'),
        [(Poly(w0 - 2, w0, domain='ZZ'), Lambda(_a, 2), Lambda(_a, -_a + x), 1)],
    )
    assert matches(result, expected)

    # Case 2: irreducible quadratic denominator.
    result = apart_list(2 / (x**2 - 2), x, dummies=numbered_symbols("w"))
    expected = (
        1,
        Poly(0, x, domain='ZZ'),
        [(Poly(w0**2 - 2, w0, domain='ZZ'), Lambda(_a, _a / 2),
          Lambda(_a, -_a + x), 1)],
    )
    assert matches(result, expected)

    # Case 3: quintic denominator giving three partial-fraction entries.
    f = 36 / (x**5 - 2 * x**4 - 2 * x**3 + 4 * x**2 + x - 2)
    result = apart_list(f, x, dummies=numbered_symbols("w"))
    entries = [
        (Poly(w0 - 2, w0, domain='ZZ'), Lambda(_a, 4), Lambda(_a, -_a + x), 1),
        (Poly(w1**2 - 1, w1, domain='ZZ'), Lambda(_a, -3 * _a - 6),
         Lambda(_a, -_a + x), 2),
        (Poly(w2 + 1, w2, domain='ZZ'), Lambda(_a, -4), Lambda(_a, -_a + x), 1),
    ]
    expected = (1, Poly(0, x, domain='ZZ'), entries)
    assert matches(result, expected)
def test_satisfiable_all_models():
    """Exercise ``satisfiable(..., all_models=True)`` as a lazy model generator."""
    from sympy.abc import A, B

    # Unsatisfiable and trivially true inputs.
    assert next(satisfiable(False, all_models=True)) is False
    contradiction = (A >> ~A) & A
    assert list(satisfiable(contradiction, all_models=True)) == [False]
    assert list(satisfiable(True, all_models=True)) == [{true: true}]

    # XOR has exactly two models; the generator must then be exhausted.
    pending = [{A: True, B: False}, {A: False, B: True}]
    result = satisfiable(A ^ B, all_models=True)
    for _ in (0, 1):
        pending.remove(next(result))
    raises(StopIteration, lambda: next(result))
    assert not pending

    both_equal = list(satisfiable(Equivalent(A, B), all_models=True))
    assert both_equal == [{A: False, B: False}, {A: True, B: True}]

    # Implication: every yielded model must be one of the three expected.
    pending = [{A: False, B: False}, {A: False, B: True}, {A: True, B: True}]
    for model in satisfiable(A >> B, all_models=True):
        pending.remove(model)
    assert not pending

    # This is a sanity test to check that only the required number
    # of solutions are generated. The expr below has 2**100 - 1 models
    # which would time out the test if all are generated at once.
    from sympy.utilities.iterables import numbered_symbols
    from sympy.logic.boolalg import Or

    sym = numbered_symbols()
    X = [next(sym) for _ in range(100)]
    result = satisfiable(Or(*X), all_models=True)
    drawn = 0
    while drawn < 10:
        assert next(result)
        drawn += 1
def test_take():
    """``take`` pulls ``n`` items from an iterator and resumes on the next call."""
    stream = numbered_symbols()

    first_batch = take(stream, 5)
    assert first_batch == list(symbols('x0:5'))

    # The same iterator continues where the previous call stopped.
    second_batch = take(stream, 5)
    assert second_batch == list(symbols('x5:10'))

    # Plain sequences work as well.
    digits = [1, 2, 3, 4, 5]
    assert take(digits, 5) == [1, 2, 3, 4, 5]
def _eval_expand_trig(self, **hints):
    """Expand ``tan`` of a sum or of an integer multiple of an argument.

    * If the argument is an ``Add``, each term's tangent is expanded
      recursively and the results are combined as a quotient of alternating
      sums of elementary symmetric polynomials: odd-degree polynomials go to
      the numerator, even-degree ones to the denominator.
    * If the argument is ``n*t`` with integer ``n > 1``, the result is
      ``im(P)/re(P)`` for ``P = (1 + I*z)**n`` with ``z`` replaced by
      ``tan(t)``.
    * Otherwise ``tan(arg)`` is returned.

    ``hints`` is accepted for interface compatibility and is not used here.
    """
    arg = self.args[0]
    if arg.is_Add:
        from sympy import symmetric_poly
        n = len(arg.args)
        TX = [tan(term, evaluate=False)._eval_expand_trig()
              for term in arg.args]
        # Placeholder symbols Y0..Y(n-1) stand for the individual tangents.
        # The builtin next() works on both Python 2.6+ and Python 3,
        # unlike the generator's ``.next()`` method.
        Yg = numbered_symbols('Y')
        Y = [next(Yg) for _ in range(n)]
        p = [0, 0]  # p[0]: numerator, p[1]: denominator
        for i in range(n + 1):
            # The sign pattern over i is +, +, -, -, repeating.
            p[1 - i % 2] += symmetric_poly(i, Y) * (-1)**((i % 4) // 2)
        # Materialize the pairs: on Python 3 ``zip`` is a one-shot iterator.
        return (p[0] / p[1]).subs(list(zip(Y, TX)))
    else:
        coeff, terms = arg.as_coeff_Mul(rational=True)
        if coeff.is_Integer and coeff > 1:
            I = S.ImaginaryUnit
            z = C.Symbol('dummy', real=True)
            P = ((1 + I * z)**coeff).expand()
            return (C.im(P) / C.re(P)).subs([(z, tan(terms))])
    return tan(arg)
def _eval_expand_trig(self, **hints):
    """Expand ``cot`` of a sum or of an integer multiple of an argument.

    * If the argument is an ``Add``, each term's cotangent is expanded
      recursively and the results are combined as a quotient of alternating
      sums of elementary symmetric polynomials, counted down from degree
      ``n`` (the number of terms).
    * If the argument is ``n*t`` with integer ``n > 1``, the result is
      ``re(P)/im(P)`` for ``P = (z + I)**n`` with ``z`` replaced by
      ``cot(t)``.
    * Otherwise ``cot(arg)`` is returned.

    ``hints`` is accepted for interface compatibility and is not used here.
    """
    arg = self.args[0]
    if arg.is_Add:
        from sympy import symmetric_poly
        n = len(arg.args)
        CX = [cot(term, evaluate=False)._eval_expand_trig()
              for term in arg.args]
        # Placeholder symbols Y0..Y(n-1) stand for the individual cotangents.
        # The builtin next() works on both Python 2.6+ and Python 3,
        # unlike the generator's ``.next()`` method.
        Yg = numbered_symbols("Y")
        Y = [next(Yg) for _ in range(n)]
        p = [0, 0]  # p[0]: numerator, p[1]: denominator
        for i in range(n, -1, -1):
            # Parity and sign depend on the distance from the top degree n.
            p[(n - i) % 2] += symmetric_poly(i, Y) * (-1) ** (((n - i) % 4) // 2)
        # Materialize the pairs: on Python 3 ``zip`` is a one-shot iterator.
        return (p[0] / p[1]).subs(list(zip(Y, CX)))
    else:
        coeff, terms = arg.as_coeff_Mul(rational=True)
        if coeff.is_Integer and coeff > 1:
            I = S.ImaginaryUnit
            z = C.Symbol("dummy", real=True)
            P = ((z + I) ** coeff).expand()
            return (C.re(P) / C.im(P)).subs([(z, cot(terms))])
    return cot(arg)
def _eval_expand_trig(self, **hints):
    """Expand ``tan`` of a sum or of an integer multiple of an argument.

    For an ``Add`` argument the tangents of the individual terms are
    expanded recursively and combined as ``p[0]/p[1]``, where ``p[0]``
    collects the odd-degree and ``p[1]`` the even-degree elementary
    symmetric polynomials with alternating signs. For ``n*t`` with integer
    ``n > 1`` the result is ``im(P)/re(P)`` with ``P = (1 + I*z)**n`` and
    ``z`` replaced by ``tan(t)``. Any other argument yields ``tan(arg)``.

    ``hints`` is accepted for interface compatibility and is not used here.
    """
    arg = self.args[0]
    if arg.is_Add:
        from sympy import symmetric_poly
        n = len(arg.args)
        TX = []
        for term in arg.args:
            TX.append(tan(term, evaluate=False)._eval_expand_trig())
        # Use the builtin next() and range(): the generator ``.next()``
        # method and ``xrange`` do not exist on Python 3.
        Yg = numbered_symbols('Y')
        Y = [next(Yg) for _ in range(n)]
        p = [0, 0]
        for i in range(n + 1):
            # Odd i feeds the numerator p[0]; even i the denominator p[1].
            p[1 - i % 2] += symmetric_poly(i, Y)*(-1)**((i % 4)//2)
        # ``subs`` gets a concrete list; Python 3's zip is a lazy iterator.
        return (p[0]/p[1]).subs(list(zip(Y, TX)))
    else:
        coeff, terms = arg.as_coeff_Mul(rational=True)
        if coeff.is_Integer and coeff > 1:
            I = S.ImaginaryUnit
            z = C.Symbol('dummy', real=True)
            P = ((1 + I*z)**coeff).expand()
            return (C.im(P)/C.re(P)).subs([(z, tan(terms))])
    return tan(arg)
def test_apart_list():
    """Check the structured output of ``apart_list`` on three rational functions."""
    from sympy.utilities.iterables import numbered_symbols

    w0, w1, w2 = Symbol("w0"), Symbol("w1"), Symbol("w2")
    _a = Dummy("a")
    zero_poly = Poly(0, x, domain="ZZ")

    # Case 1: improper fraction with a single linear factor.
    f = (-2 * x - 2 * x**2) / (3 * x**2 - 6 * x)
    got = apart_list(f, x, dummies=numbered_symbols("w"))
    want = (
        -1,
        Poly(Rational(2, 3), x, domain="QQ"),
        [(Poly(w0 - 2, w0, domain="ZZ"), Lambda(_a, 2), Lambda(_a, -_a + x), 1)],
    )
    assert got == want

    # Case 2: irreducible quadratic denominator.
    got = apart_list(2 / (x**2 - 2), x, dummies=numbered_symbols("w"))
    quadratic_entry = (
        Poly(w0**2 - 2, w0, domain="ZZ"),
        Lambda(_a, _a / 2),
        Lambda(_a, -_a + x),
        1,
    )
    assert got == (1, zero_poly, [quadratic_entry])

    # Case 3: quintic denominator giving three partial-fraction entries.
    f = 36 / (x**5 - 2 * x**4 - 2 * x**3 + 4 * x**2 + x - 2)
    got = apart_list(f, x, dummies=numbered_symbols("w"))
    want_entries = [
        (Poly(w0 - 2, w0, domain="ZZ"), Lambda(_a, 4), Lambda(_a, -_a + x), 1),
        (
            Poly(w1**2 - 1, w1, domain="ZZ"),
            Lambda(_a, -3 * _a - 6),
            Lambda(_a, -_a + x),
            2,
        ),
        (Poly(w2 + 1, w2, domain="ZZ"), Lambda(_a, -4), Lambda(_a, -_a + x), 1),
    ]
    assert got == (1, zero_poly, want_entries)
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    :arg exprs: A single sympy expression or a list of them; these are the
        expressions to reduce.
    :arg symbols: An infinite iterator of unique Symbols used to name the
        subexpressions that get pulled out (``numbered_symbols`` is a
        convenient source). Defaults to the stream "x0", "x1", etc.
    :arg optimizations: A list of (preprocessor, postprocessor) callable
        pairs applied before and after the elimination. Defaults to an
        empty list.

    :return: A pair ``(replacements, reduced_exprs)``.

        * ``replacements`` is a list of (Symbol, expression) pairs, one per
          extracted subexpression. Earlier entries may appear inside later
          ones.
        * ``reduced_exprs`` is the list of input expressions rewritten in
          terms of the replacement symbols.
    """
    # Normalize the input to a list, wrapping a lone expression.
    exprs = [exprs] if isinstance(exprs, Basic) else list(exprs)

    if optimizations is None:
        optimizations = []

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(expr, optimizations) for expr in exprs]

    # ``iter`` covers iterables that are not already iterators.
    symbols = numbered_symbols(cls=Symbol) if symbols is None else iter(symbols)

    # Find other optimization opportunities, then run the main CSE algorithm.
    opt_subs = opt_cse(reduced_exprs)
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs)

    # Postprocess everything to return the expressions to canonical form.
    replacements[:] = [
        (sym, postprocess_for_cse(subtree, optimizations))
        for sym, subtree in replacements
    ]
    reduced_exprs = [
        postprocess_for_cse(expr, optimizations) for expr in reduced_exprs
    ]

    return replacements, reduced_exprs
def build_cse_fn(symname, symfunc, symbolslist):
    """Return source text for a ``double``-valued function computing ``symfunc``.

    Common subexpressions of ``symfunc`` are extracted with ``cse`` (named
    R0, R1, ...) and emitted as local ``double`` variables ahead of the final
    ``result``. Each entry of ``symbolslist`` becomes a ``double const&``
    parameter of the generated function named ``symname``.
    """
    replacements, reduced = cse(symfunc, symbols=numbered_symbols("R"))

    parameters = ", ".join("double const& %s" % name for name in symbolslist)
    pieces = ["double %s(%s)\n" % (str(symname), parameters), "{\n"]
    for pair in replacements:
        pieces.append(" double %s = %s;\n" % (ccode(pair[0]), ccode(pair[1])))
    pieces.append(" double result = %s;\n" % ccode(reduced[0]))
    pieces.append(" return result;\n")
    pieces.append("}\n")
    return "".join(pieces)
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The `numbered_symbols` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        `sympy.simplify.cse.cse_optimizations` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    to_eliminate = []

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    for expr in exprs:
        for subtree in postorder_traversal(expr):
            if subtree.args == ():
                # Exclude atoms, since there is no point in renaming them.
                continue
            # Atoms were skipped above, so only the repeat check remains.
            # ``to_eliminate`` stays a list to keep the discovery order.
            if subtree in seen_subexp and subtree not in to_eliminate:
                to_eliminate.append(subtree)
            seen_subexp.add(subtree)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        # The builtin next() works on Python 2.6+ and Python 3; the iterator
        # method ``.next()`` exists only on Python 2.
        sym = next(symbols)
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        # Only entries after position i are rewritten, so the element the
        # enclosing loop is currently visiting is never disturbed.
        for j in range(i + 1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    return replacements, reduced_exprs
def test_numbered_symbols():
    """``numbered_symbols`` honours ``cls`` and skips names listed in ``exclude``."""
    dummy_stream = numbered_symbols(cls=Dummy)
    first = next(dummy_stream)
    assert isinstance(first, Dummy)

    # Starting at 1 with C1 excluded, the first symbol produced is C2.
    skipped = symbols('C1')
    c_stream = numbered_symbols('C', start=1, exclude=[skipped])
    assert next(c_stream) == symbols('C2')
def extract_sub_expressions(self, cache_prefix='cache', sub_prefix='sub', prefix='XoXoXoX'):
    """Run common-subexpression elimination over ``self.expression_list``.

    The extracted sub-expressions are split into two groups:

    * "cacheable" ones, which depend (directly or through an earlier
      cacheable sub-expression) on a symbol in ``self.cacheable_vars``; they
      are renamed ``<cache_prefix>0``, ``<cache_prefix>1``, ...
    * all others, renamed ``<sub_prefix>0``, ``<sub_prefix>1``, ...

    Side effects: fills ``self.variables[cache_prefix]`` and
    ``self.variables[sub_prefix]`` with the new symbols, rewrites the entries
    of ``self.expressions`` addressed by ``self.expression_keys`` in terms of
    them, and stores the definitions of the sub-expressions in
    ``self.expressions['update_cache']`` and
    ``self.expressions['parameters_changed']``.

    ``prefix`` is the name prefix of the temporary symbols handed to
    ``sym.cse``.
    """
    # Do the common sub expression elimination.
    common_sub_expressions, expression_substituted_list = sym.cse(self.expression_list, numbered_symbols(prefix=prefix))

    self.variables[cache_prefix] = []
    self.variables[sub_prefix] = []

    # Create dictionary of new sub expressions
    # NOTE(review): sub_expression_dict is filled but not read again in
    # this method.
    sub_expression_dict = {}
    for var, void in common_sub_expressions:
        sub_expression_dict[var.name] = var

    # Sort out any expression that's dependent on something that scales with
    # data size (these are listed in cacheable).
    cacheable_list = []
    params_change_list = []
    # common_sub_expressions contains a list of paired tuples with the new
    # variable and what it equals
    for var, expr in common_sub_expressions:
        arg_list = [e for e in expr.atoms() if e.is_Symbol]
        # List any cacheable dependencies of the sub-expression
        cacheable_symbols = [e for e in arg_list if e in cacheable_list or e in self.cacheable_vars]
        if cacheable_symbols:
            # list which ensures dependencies are cacheable.
            cacheable_list.append(var)
        else:
            params_change_list.append(var)

    # Map each temporary cse symbol name to its final cache/sub symbol.
    # NOTE(review): sym.var presumably also injects the new name into the
    # calling namespace -- confirm this is intended.
    replace_dict = {}
    for i, expr in enumerate(cacheable_list):
        sym_var = sym.var(cache_prefix + str(i))
        self.variables[cache_prefix].append(sym_var)
        replace_dict[expr.name] = sym_var

    for i, expr in enumerate(params_change_list):
        sym_var = sym.var(sub_prefix + str(i))
        self.variables[sub_prefix].append(sym_var)
        replace_dict[expr.name] = sym_var

    for replace, void in common_sub_expressions:
        # Each pass starts again from the cse output, so after this loop
        # only the last replacement is applied here; the loop further below
        # applies all of them cumulatively.
        for expr, keys in zip(expression_substituted_list, self.expression_keys):
            setInDict(self.expressions, keys, expr.subs(replace, replace_dict[replace.name]))
        # NOTE(review): the substituted value is only bound to the local
        # name ``expr`` and then discarded, so this inner loop has no
        # lasting effect -- confirm whether it can be removed.
        for void, expr in common_sub_expressions:
            expr = expr.subs(replace, replace_dict[replace.name])

    # Replace original code with code including subexpressions.
    for keys in self.expression_keys:
        for replace, void in common_sub_expressions:
            setInDict(self.expressions, keys, getFromDict(self.expressions, keys).subs(replace, replace_dict[replace.name]))

    # Store each sub-expression's definition, rewritten with the final
    # names, under the bucket matching its group.
    self.expressions['parameters_changed'] = {}
    self.expressions['update_cache'] = {}
    for var, expr in common_sub_expressions:
        for replace, void in common_sub_expressions:
            expr = expr.subs(replace, replace_dict[replace.name])
        if var in cacheable_list:
            self.expressions['update_cache'][replace_dict[var.name].name] = expr
        else:
            self.expressions['parameters_changed'][replace_dict[var.name].name] = expr
def test_numbered_symbols():
    """``numbered_symbols`` honours ``cls`` and skips names listed in ``exclude``."""
    generated = next(numbered_symbols(cls=Dummy))
    assert isinstance(generated, Dummy)

    # Starting at 1 with C1 excluded, the first symbol produced is C2.
    excluded = [symbols("C1")]
    expected = symbols("C2")
    assert next(numbered_symbols("C", start=1, exclude=excluded)) == expected
def __init__(
    self,
    args,
    expr,
    print_lambda=False,
    use_evalf=False,
    float_wrap_evalf=False,
    complex_wrap_evalf=False,
    use_np=False,
    use_python_math=False,
    use_python_cmath=False,
    use_interval=False,
):
    """Translate ``expr`` into a Python lambda over ``args``.

    ``args`` must all be Symbols, otherwise ``ValueError`` is raised. The
    remaining flags are stored on the instance; ``use_python_math``,
    ``use_python_cmath``, ``use_np`` and ``use_interval`` additionally decide
    which modules are placed in the namespace of the generated lambda.
    ``ImportError`` is raised when ``use_np`` is set and numpy is missing.

    After construction ``self.eval_str`` holds the lambda source and
    ``self.lambda_func`` the compiled callable.
    """
    self.print_lambda = print_lambda
    self.use_evalf = use_evalf
    self.float_wrap_evalf = float_wrap_evalf
    self.complex_wrap_evalf = complex_wrap_evalf
    self.use_np = use_np
    self.use_python_math = use_python_math
    self.use_python_cmath = use_python_cmath
    self.use_interval = use_interval

    # Constructing the argument string
    # - check
    if any(not isinstance(a, Symbol) for a in args):
        raise ValueError("The arguments must be Symbols.")
    # - swap the arguments for numbered symbols that do not clash with
    #   anything already free in the expression
    syms = numbered_symbols(exclude=expr.free_symbols)
    newargs = [next(syms) for _ in args]
    renaming = dict(zip(args, newargs))
    expr = expr.xreplace(renaming)
    argstr = ", ".join(map(str, newargs))
    del syms, newargs, args

    # Constructing the translation dictionaries and making the translation
    self.dict_str = self.get_dict_str()
    self.dict_fun = self.get_dict_fun()
    exprstr = str(expr)
    tree = self.str2tree(exprstr)
    newexpr = self.tree2str_translate(tree)

    # Constructing the namespaces
    namespace = {}
    namespace.update(self.sympy_atoms_namespace(expr))
    namespace.update(self.sympy_expression_namespace(expr))
    # XXX Workaround
    # Ugly workaround because Pow(a,Half) prints as sqrt(a)
    # and sympy_expression_namespace can not catch it.
    from sympy import sqrt

    namespace["sqrt"] = sqrt
    namespace["Eq"] = lambda x, y: x == y
    namespace["Ne"] = lambda x, y: x != y
    # End workaround.

    if use_python_math:
        namespace["math"] = __import__("math")
    if use_python_cmath:
        namespace["cmath"] = __import__("cmath")
    if use_np:
        try:
            namespace["np"] = __import__("numpy")
        except ImportError:
            raise ImportError(
                "experimental_lambdify failed to import numpy.")
    if use_interval:
        namespace["imath"] = __import__(
            "sympy.plotting.intervalmath", fromlist=["intervalmath"])
        namespace["math"] = __import__("math")

    # Construct the lambda
    if self.print_lambda:
        print(newexpr)
    eval_str = "lambda %s : ( %s )" % (argstr, newexpr)
    self.eval_str = eval_str
    exec_("from __future__ import division; MYNEWLAMBDA = %s" % eval_str,
          namespace)
    self.lambda_func = namespace["MYNEWLAMBDA"]
def _eval_rewrite_as_sqrt(self, arg):
    """Rewrite ``cos(arg)`` using square roots where this code knows how.

    Returns ``None`` when ``arg`` is not a rational multiple of pi or when
    no rewrite is found. Handled cases:

    * the denominator of the pi coefficient is a key of ``cst_table_some``
      (3, 5 or 17): a Chebyshev polynomial of the tabulated value is used;
    * the denominator is even: the half-angle formula is applied on top of
      the rewrite of ``cos(2*arg)``;
    * the denominator is a product of distinct keys of ``cst_table_some``:
      the coefficient is split into partial fractions and the cosine of
      the sum is expanded.
    """
    # When True, arguments that cannot be split by Fermat primes are still
    # decomposed; the half-angle branch then also accepts cos results.
    _EXPAND_INTS = False

    def migcdex(x):
        # recursive calculation of gcd and linear combination
        # for a sequence of integers.
        # Given (x1, x2, x3)
        # Returns (y1, y2, y3, g)
        # such that g is the gcd and x1*y1+x2*y2+x3*y3 - g = 0
        # Note, that this is only one such linear combination.
        if len(x) == 1:
            return (1, x[0])
        if len(x) == 2:
            return igcdex(x[0], x[-1])
        g = migcdex(x[1:])
        u, v, h = igcdex(x[0], g[-1])
        return tuple([u] + [v * i for i in g[0:-1]] + [h])

    def ipartfrac(r, factors=None):
        # Split the Rational r into a list of Rationals that sum to r; the
        # split is driven by ``factors`` when given, otherwise by the
        # prime-power factors of r.q.
        if isinstance(r, int):
            return r
        assert isinstance(r, C.Rational)
        n = r.q
        if 2 > r.q * r.q:
            return r.q

        if factors is None:
            # ``dict.items`` exists on Python 2 and 3 (``iteritems`` does
            # not exist on Python 3).
            a = [n // x**y for x, y in factorint(r.q).items()]
        else:
            a = [n // x for x in factors]
        if len(a) == 1:
            return [r]
        h = migcdex(a)
        ans = [r.p * C.Rational(i * j, r.q) for i, j in zip(h[:-1], a)]
        assert r == sum(ans)
        return ans

    pi_coeff = _pi_coeff(arg)
    if pi_coeff is None:
        return None

    assert not pi_coeff.is_integer, "should have been simplified already"

    if not pi_coeff.is_Rational:
        return None

    # Tabulated values used as the Chebyshev argument for denominators
    # 3, 5 and 17.
    cst_table_some = {
        3: S.Half,
        5: (sqrt(5) + 1) / 4,
        17: sqrt((15 + sqrt(17)) / 32 + sqrt(2) * (sqrt(17 - sqrt(17)) +
            sqrt(sqrt(2) * (-8 * sqrt(17 + sqrt(17)) - (1 - sqrt(17))
            * sqrt(17 - sqrt(17))) + 6 * sqrt(17) + 34)) / 32)
        # 65537 and 257 are the only other known Fermat primes
        # Please add if you would like them
    }

    def fermatCoords(n):
        # Return the tuple of tabulated primes whose product is n when each
        # occurs exactly once; otherwise return False.
        assert isinstance(n, int)
        assert n > 0
        if n == 1 or 0 == n % 2:
            return False
        primes = dict([(p, 0) for p in cst_table_some])
        assert 1 not in primes
        for p_i in primes:
            while 0 == n % p_i:
                # Floor division keeps n an exact integer; ``/`` yields a
                # float under true division.
                n = n // p_i
                primes[p_i] += 1
        if 1 != n:
            return False
        if max(primes.values()) > 1:
            return False
        return tuple([p for p in primes if primes[p] == 1])

    if pi_coeff.q in cst_table_some:
        return C.chebyshevt(pi_coeff.p, cst_table_some[pi_coeff.q]).expand()

    if 0 == pi_coeff.q % 2:  # recursively remove powers of 2
        narg = (pi_coeff * 2) * S.Pi
        nval = cos(narg)
        if nval is None:
            return None
        nval = nval.rewrite(sqrt)
        if not _EXPAND_INTS:
            # Give up if the doubled angle could not be rewritten.
            if (isinstance(nval, cos) or isinstance(-nval, cos)):
                return None
        x = (2 * pi_coeff + 1) / 2
        sign_cos = (-1)**((-1 if x < 0 else 1) * int(abs(x)))
        return sign_cos * sqrt((1 + nval) / 2)

    FC = fermatCoords(pi_coeff.q)
    if FC:
        decomp = ipartfrac(pi_coeff, FC)
        # Pair each partial fraction (times pi) with a placeholder symbol.
        X = [(x[1], x[0] * S.Pi) for x in zip(decomp, numbered_symbols('z'))]
        pcls = cos(sum([x[0] for x in X]))._eval_expand_trig().subs(X)
        return pcls.rewrite(sqrt)
    if _EXPAND_INTS:
        decomp = ipartfrac(pi_coeff)
        X = [(x[1], x[0] * S.Pi) for x in zip(decomp, numbered_symbols('z'))]
        pcls = cos(sum([x[0] for x in X]))._eval_expand_trig().subs(X)
        return pcls
    return None
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
        order='canonical', ignore=()):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce. Any iterable is accepted; it is copied
        into a list before processing.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs of external optimization
        functions. Optionally 'basic' can be passed for a set of predefined
        basic optimizations. Such 'basic' optimizations were used by default
        in old implementation, however they can be really slow on larger
        expressions. Now, no pre or post optimizations are made by default.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: (reversed(r), e)
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. If set to
        'canonical', arguments will be canonically ordered. If set to 'none',
        ordering will be faster but dependent on expressions hashes, thus
        machine dependent and variable. For large expressions where speed is a
        concern, use the setting order='none'.
    ignore : iterable of Symbols
        Substitutions containing any Symbol from ``ignore`` will be ignored.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.

    Examples
    ========

    >>> from sympy import cse, SparseMatrix
    >>> from sympy.abc import x, y, z, w
    >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3)
    ([(x0, w + y + z)], [x0*(x + x0)/(w + x)**3])

    Note that currently, y + z will not get substituted if -y - z is used.

    >>> cse(((w + x + y + z)*(w - y - z))/(w + x)**3)
    ([(x0, w + x)], [(w - y - z)*(x0 + y + z)/x0**3])

    List of expressions with recursive substitutions:

    >>> m = SparseMatrix([x + y, x + y + z])
    >>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m])
    ([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([
    [x0],
    [x1]])])

    Note: the type and mutability of input matrices is retained.

    >>> isinstance(_[1][-1], SparseMatrix)
    True

    The user may disallow substitutions containing certain symbols:

    >>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,))
    ([(x0, x + 1)], [x0*y**2, 3*x0*y**2])

    """
    from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix,
                                SparseMatrix, ImmutableSparseMatrix)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, (Basic, MatrixBase)):
        exprs = [exprs]

    # The inputs are walked twice: here to flatten matrices and again at the
    # end to rebuild them. Copy them into a list so that a one-shot iterable
    # (e.g. a generator) is not exhausted by the first pass.
    originals = list(exprs)

    # Matrices are flattened into Tuples so the rest of the algorithm only
    # has to deal with ordinary expressions.
    exprs = []
    for e in originals:
        if isinstance(e, (Matrix, ImmutableMatrix)):
            exprs.append(Tuple(*e._mat))
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            exprs.append(Tuple(*e._smat.items()))
        else:
            exprs.append(e)

    if optimizations is None:
        optimizations = []
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Symbols already present in the input must not be reused as labels.
    excluded_symbols = set().union(*[expr.atoms(Symbol)
                                     for expr in reduced_exprs])

    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)

    symbols = filter_symbols(symbols, excluded_symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order, ignore)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    # Get the matrices back, keeping the type and mutability of each input.
    for i, e in enumerate(originals):
        if isinstance(e, (Matrix, ImmutableMatrix)):
            reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
            if isinstance(e, ImmutableMatrix):
                reduced_exprs[i] = reduced_exprs[i].as_immutable()
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            m = SparseMatrix(e.rows, e.cols, {})
            for k, v in reduced_exprs[i]:
                m[k] = v
            if isinstance(e, ImmutableSparseMatrix):
                m = m.as_immutable()
            reduced_exprs[i] = m

    if postprocess is None:
        return replacements, reduced_exprs

    return postprocess(replacements, reduced_exprs)
def cse(exprs, symbols=None, optimizations=None, postprocess=None):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        ``sympy.simplify.cse.cse_optimizations`` is used.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: (reversed(r), e)

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    from sympy.matrices import Matrix

    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    muls = set()
    adds = set()
    to_eliminate = set()

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    for expr in reduced_exprs:
        if not isinstance(expr, Basic):
            continue
        pt = preorder_traversal(expr)
        for subtree in pt:

            # For powers, also track the reciprocal so that x**-n and x**n
            # are treated as occurrences of the same subexpression.
            inv = 1 / subtree if subtree.is_Pow else None

            if subtree.is_Atom or iterable(subtree) or inv and inv.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue

            if subtree in seen_subexp:
                if inv and _coeff_isneg(subtree.exp):
                    # save the form with positive exponent
                    subtree = inv
                to_eliminate.add(subtree)
                # Do not descend into a subtree that is being eliminated.
                pt.skip()
                continue

            if inv and inv in seen_subexp:
                if _coeff_isneg(subtree.exp):
                    # save the form with positive exponent
                    subtree = inv
                to_eliminate.add(subtree)
                pt.skip()
                continue
            elif subtree.is_Mul:
                muls.add(subtree)
            elif subtree.is_Add:
                adds.add(subtree)

            seen_subexp.add(subtree)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in ordered(adds)]
    for i in xrange(len(adds)):
        for j in xrange(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                to_eliminate.add(Add(*com))

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                for k in xrange(j + 1, len(adds)):
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    # Each entry is indexed below as [0] (a set: the commutative factors)
    # and [1] (a sequence: the non-commutative factors).
    muls = [a.args_cnc(cset=True) for a in ordered(muls)]
    for i in xrange(len(muls)):
        if muls[i][1]:
            sm.set_seq1(muls[i][1])
        for j in xrange(i + 1, len(muls)):
            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                if len(ccom) + len(ncom) < 2:
                    continue

                # now work harder to find the match
                sm.set_seq2(muls[j][1])
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]),
                                                 0, len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
            else:
                ncom = []

            com = list(ccom) + ncom
            if len(com) < 2:
                continue

            to_eliminate.add(Mul(*com))

            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                for k in xrange(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # make to_eliminate canonical; we will prefer non-Muls to Muls
    # so select them first (non-Muls will have False for is_Mul and will
    # be first in the ordering.
    to_eliminate = list(ordered(to_eliminate, lambda _: _.is_Mul))

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(reduced_exprs)
    # ``hit`` records whether the current symbol was actually used; a new
    # symbol is drawn only after the previous one produced a change.
    hit = True
    for i, subtree in enumerate(to_eliminate):
        if hit:
            sym = next(symbols)
        hit = False
        # The lambdas read ``subtree`` and ``sym`` when called; they are
        # only called within this same iteration.
        if subtree.is_Pow and subtree.exp.is_Rational:
            update = lambda x: x.xreplace({subtree: sym, 1 / subtree: 1 / sym})
        else:
            update = lambda x: x.subs(subtree, sym)
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            old = reduced_exprs[j]
            reduced_exprs[j] = update(expr)
            hit = hit or (old != reduced_exprs[j])
        # Make the substitution in all of the subsequent substitutions.
        for j in range(i + 1, len(to_eliminate)):
            old = to_eliminate[j]
            to_eliminate[j] = update(to_eliminate[j])
            hit = hit or (old != to_eliminate[j])
        if hit:
            replacements.append((sym, subtree))

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    # remove replacements that weren't used more than once
    _remove_singletons(replacements, reduced_exprs)

    # A Matrix input comes back as a single Matrix of the same shape.
    if isinstance(exprs, Matrix):
        reduced_exprs = [Matrix(exprs.rows, exprs.cols, reduced_exprs)]
    if postprocess is None:
        return replacements, reduced_exprs
    return postprocess(replacements, reduced_exprs)
def __init__(self, args, expr, print_lambda=False, use_evalf=False,
             float_wrap_evalf=False, complex_wrap_evalf=False,
             use_np=False, use_python_math=False,
             use_python_cmath=False, use_interval=False):
    """Compile ``expr`` into a Python lambda whose parameters are ``args``.

    The expression is printed with ``str``, passed through this object's
    ``str2tree`` / ``tree2str_translate`` helpers and executed as
    ``lambda <args> : ( <translated expr> )`` inside a namespace built
    here.  The callable ends up in ``self.lambda_func`` and its source text
    in ``self.eval_str``.

    Parameters
    ==========

    args : iterable of Symbol
        Parameters of the lambda.  It is iterated more than once.  Each
        entry is swapped for a freshly numbered symbol that does not occur
        in ``expr`` before the expression is printed.
    expr : sympy expression
        The expression to translate.
    print_lambda : bool
        When true, the translated expression string is printed.
    use_evalf, float_wrap_evalf, complex_wrap_evalf : bool
        Stored on the instance; not otherwise used in this method.
    use_np : bool
        Bind ``numpy`` as ``np`` in the evaluation namespace.
    use_python_math, use_python_cmath : bool
        Bind the ``math`` / ``cmath`` modules in the evaluation namespace.
    use_interval : bool
        Bind ``sympy.plotting.intervalmath`` as ``imath`` (and ``math``).

    Raises
    ======

    ValueError
        If any entry of ``args`` is not a ``Symbol``.
    ImportError
        If ``use_np`` is set and ``numpy`` cannot be imported.
    """
    # Keep every translation switch on the instance.
    self.print_lambda = print_lambda
    self.use_evalf = use_evalf
    self.float_wrap_evalf = float_wrap_evalf
    self.complex_wrap_evalf = complex_wrap_evalf
    self.use_np = use_np
    self.use_python_math = use_python_math
    self.use_python_cmath = use_python_cmath
    self.use_interval = use_interval

    # --- Argument string -------------------------------------------------
    if any(not isinstance(candidate, Symbol) for candidate in args):
        raise ValueError('The arguments must be Symbols.')
    # Replace the caller's symbols by numbered ones that cannot clash with
    # anything already present in the expression.
    fresh = numbered_symbols(exclude=expr.free_symbols)
    fresh_args = [next(fresh) for _ in args]
    expr = expr.xreplace(dict(zip(args, fresh_args)))
    argstr = ', '.join(map(str, fresh_args))
    del fresh, fresh_args, args

    # --- Translation -----------------------------------------------------
    self.dict_str = self.get_dict_str()
    self.dict_fun = self.get_dict_fun()
    exprstr = str(expr)
    # the & and | operators don't work on tuples, see discussion #12108
    exprstr = exprstr.replace(" & "," and ")
    exprstr = exprstr.replace(" | "," or ")
    tree = self.str2tree(exprstr)
    newexpr = self.tree2str_translate(tree)

    # --- Namespace -------------------------------------------------------
    namespace = {}
    namespace.update(self.sympy_atoms_namespace(expr))
    namespace.update(self.sympy_expression_namespace(expr))
    # XXX Workaround
    # Pow(a, Half) prints as sqrt(a) and sympy_expression_namespace can not
    # catch it, so sqrt (and a plain-equality Eq) are injected by hand.
    from sympy import sqrt
    namespace['sqrt'] = sqrt
    namespace['Eq'] = lambda x, y: x == y
    # End workaround.
    if use_python_math:
        namespace['math'] = __import__('math')
    if use_python_cmath:
        namespace['cmath'] = __import__('cmath')
    if use_np:
        try:
            namespace['np'] = __import__('numpy')
        except ImportError:
            raise ImportError(
                'experimental_lambdify failed to import numpy.')
    if use_interval:
        namespace['imath'] = __import__(
            'sympy.plotting.intervalmath', fromlist=['intervalmath'])
        namespace['math'] = __import__('math')

    # --- Build the lambda ------------------------------------------------
    if self.print_lambda:
        print(newexpr)
    self.eval_str = 'lambda %s : ( %s )' % (argstr, newexpr)
    exec_("from __future__ import division; MYNEWLAMBDA = %s" % self.eval_str,
          namespace)
    self.lambda_func = namespace['MYNEWLAMBDA']
def cse(exprs, symbols=None, optimizations=None, postprocess=None):
    """ Perform common subexpression elimination on an expression.

    Every non-atomic subexpression is first labelled with a temporary
    ``_csetmp`` symbol; labels that turn out to be referenced more than once
    are then replaced by symbols drawn from ``symbols`` and the remaining
    ones are expanded back in place.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        ``sympy.simplify.cse.cse_optimizations`` is used.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: (reversed(r), e)

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    from sympy.matrices import Matrix

    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    tmp_symbols = numbered_symbols('_csetmp')
    # maps subexpression -> temporary symbol standing for it
    subexp_iv = dict()
    muls = set()
    adds = set()

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    prep_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all subexpressions.
    def _parse(expr):
        """Rebuild ``expr`` bottom-up, replacing every non-atomic node by
        its temporary symbol (creating one on first sight)."""
        if expr.is_Atom:
            # Exclude atoms, since there is no point in renaming them.
            return expr
        if iterable(expr):
            return expr
        subexpr = type(expr)(*map(_parse, expr.args))
        if subexpr in subexp_iv:
            return subexp_iv[subexpr]
        if subexpr.is_Mul:
            muls.add(subexpr)
        elif subexpr.is_Add:
            adds.add(subexpr)
        ivar = next(tmp_symbols)
        subexp_iv[subexpr] = ivar
        return ivar

    tmp_exprs = list()
    for expr in prep_exprs:
        if isinstance(expr, Basic):
            tmp_exprs.append(_parse(expr))
        else:
            tmp_exprs.append(expr)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = list(ordered(adds))
    addargs = [set(a.args) for a in adds]
    for i in range(len(addargs)):
        for j in range(i + 1, len(addargs)):
            com = addargs[i].intersection(addargs[j])
            if len(com) > 1:
                add_subexp = Add(*com)
                diff_add_i = addargs[i].difference(com)
                diff_add_j = addargs[j].difference(com)
                if add_subexp in subexp_iv:
                    ivar = subexp_iv[add_subexp]
                else:
                    ivar = next(tmp_symbols)
                    subexp_iv[add_subexp] = ivar
                # Re-key each affected Add so that it is expressed through
                # the shared part's temporary symbol.
                if diff_add_i:
                    newadd = Add(ivar, *diff_add_i)
                    subexp_iv[newadd] = subexp_iv.pop(adds[i])
                    adds[i] = newadd
                # else add_i is itself subexp_iv[add_subexp] -> ivar
                if diff_add_j:
                    newadd = Add(ivar, *diff_add_j)
                    subexp_iv[newadd] = subexp_iv.pop(adds[j])
                    adds[j] = newadd
                # else add_j is itself subexp_iv[add_subexp] -> ivar
                addargs[i] = diff_add_i
                addargs[j] = diff_add_j
                for k in range(j + 1, len(addargs)):
                    if com.issubset(addargs[k]):
                        diff_add_k = addargs[k].difference(com)
                        if diff_add_k:
                            newadd = Add(ivar, *diff_add_k)
                            subexp_iv[newadd] = subexp_iv.pop(adds[k])
                            adds[k] = newadd
                        # else add_k is itself subexp_iv[add_subexp] -> ivar
                        addargs[k] = diff_add_k

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common
    # *assumes that there are no non-commutative parts*
    muls = list(ordered(muls))
    mulargs = [set(a.args) for a in muls]
    for i in range(len(mulargs)):
        for j in range(i + 1, len(mulargs)):
            com = mulargs[i].intersection(mulargs[j])
            if len(com) > 1:
                mul_subexp = Mul(*com)
                diff_mul_i = mulargs[i].difference(com)
                diff_mul_j = mulargs[j].difference(com)
                if mul_subexp in subexp_iv:
                    ivar = subexp_iv[mul_subexp]
                else:
                    ivar = next(tmp_symbols)
                    subexp_iv[mul_subexp] = ivar
                if diff_mul_i:
                    newmul = Mul(ivar, *diff_mul_i)
                    subexp_iv[newmul] = subexp_iv.pop(muls[i])
                    muls[i] = newmul
                # else mul_i is itself subexp_iv[mul_subexp] -> ivar
                if diff_mul_j:
                    newmul = Mul(ivar, *diff_mul_j)
                    subexp_iv[newmul] = subexp_iv.pop(muls[j])
                    muls[j] = newmul
                # else mul_j is itself subexp_iv[mul_subexp] -> ivar
                mulargs[i] = diff_mul_i
                mulargs[j] = diff_mul_j
                for k in range(j + 1, len(mulargs)):
                    if com.issubset(mulargs[k]):
                        diff_mul_k = mulargs[k].difference(com)
                        if diff_mul_k:
                            newmul = Mul(ivar, *diff_mul_k)
                            subexp_iv[newmul] = subexp_iv.pop(muls[k])
                            muls[k] = newmul
                        # else mul_k is itself subexp_iv[mul_subexp] -> ivar
                        mulargs[k] = diff_mul_k

    # Find all of the repeated subexpressions.
    # Inverse map: temporary symbol -> subexpression.  ``items()`` is used
    # (not ``iteritems()``) so this works on both Python 2 and Python 3.
    ivar_se = {iv: se for se, iv in subexp_iv.items()}

    used_ivs = set()
    repeated = set()

    def _find_repeated_subexprs(subexpr):
        """Walk through the temporary symbols reachable from ``subexpr``
        and record those that are met a second time in ``repeated``."""
        if subexpr.is_Atom:
            symbs = [subexpr]
        else:
            symbs = subexpr.args
        for symb in symbs:
            if symb in ivar_se:
                if symb not in used_ivs:
                    _find_repeated_subexprs(ivar_se[symb])
                    used_ivs.add(symb)
                else:
                    repeated.add(symb)

    for expr in tmp_exprs:
        _find_repeated_subexprs(expr)

    # Substitute symbols for all of the repeated subexpressions.
    # remove temporary replacements that weren't used more than once
    tmpivs_ivs = dict()
    ordered_iv_se = OrderedDict()

    def _get_subexprs(args):
        """Return ``args`` with temporary symbols resolved: repeated ones
        become real replacement symbols, the others are expanded inline."""
        args = list(args)
        for i, symb in enumerate(args):
            if symb in ivar_se:
                if symb in tmpivs_ivs:
                    args[i] = tmpivs_ivs[symb]
                else:
                    subexpr = ivar_se[symb]
                    subexpr = type(subexpr)(*_get_subexprs(subexpr.args))
                    if symb in repeated:
                        ivar = next(symbols)
                        ordered_iv_se[ivar] = subexpr
                        tmpivs_ivs[symb] = ivar
                        args[i] = ivar
                    else:
                        args[i] = subexpr
        return args

    out_exprs = _get_subexprs(tmp_exprs)

    # Postprocess the expressions to return the expressions to canonical form.
    ordered_iv_se_notopt = ordered_iv_se
    ordered_iv_se = OrderedDict()
    for ivar, subexpr in ordered_iv_se_notopt.items():
        ordered_iv_se[ivar] = postprocess_for_cse(subexpr, optimizations)
    out_exprs = [postprocess_for_cse(e, optimizations) for e in out_exprs]

    if isinstance(exprs, Matrix):
        out_exprs = Matrix(exprs.rows, exprs.cols, out_exprs)

    # Hand out a real list, as documented; on Python 3 ``items()`` alone
    # would be a view object.
    replacements = list(ordered_iv_se.items())
    if postprocess is None:
        return replacements, out_exprs
    return postprocess(replacements, out_exprs)
def cse(exprs, symbols=None, optimizations=None, postprocess=None):
    """ Perform common subexpression elimination on an expression.

    Repeated subtrees, shared groups of Add arguments and shared groups of
    Mul factors are collected into ``to_eliminate`` and then substituted,
    one symbol at a time, into the expressions and into the remaining
    candidates.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        ``sympy.simplify.cse.cse_optimizations`` is used.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: (reversed(r), e)

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    from sympy.matrices import Matrix

    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    muls = set()
    adds = set()
    to_eliminate = set()

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    for expr in reduced_exprs:
        if not isinstance(expr, Basic):
            continue
        pt = preorder_traversal(expr)
        for subtree in pt:

            # For powers also look at the reciprocal, so that a power and
            # its inverse count as the same subexpression.
            inv = 1/subtree if subtree.is_Pow else None

            if subtree.is_Atom or iterable(subtree) or inv and inv.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue

            if subtree in seen_subexp:
                if inv and _coeff_isneg(subtree.exp):
                    # save the form with positive exponent
                    subtree = inv
                to_eliminate.add(subtree)
                # NOTE(review): skip() presumably keeps the traversal from
                # descending into this subtree - confirm against
                # preorder_traversal.
                pt.skip()
                continue

            if inv and inv in seen_subexp:
                if _coeff_isneg(subtree.exp):
                    # save the form with positive exponent
                    subtree = inv
                to_eliminate.add(subtree)
                pt.skip()
                continue
            elif subtree.is_Mul:
                muls.add(subtree)
            elif subtree.is_Add:
                adds.add(subtree)

            seen_subexp.add(subtree)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in ordered(adds)]
    for i in xrange(len(adds)):
        for j in xrange(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                to_eliminate.add(Add(*com))

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                for k in xrange(j + 1, len(adds)):
                    # an empty difference means com is a subset of adds[k]
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    # each entry is indexed below as [commutative set, non-commutative list]
    muls = [a.args_cnc(cset=True) for a in ordered(muls)]
    for i in xrange(len(muls)):
        if muls[i][1]:
            sm.set_seq1(muls[i][1])
        for j in xrange(i + 1, len(muls)):
            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                if len(ccom) + len(ncom) < 2:
                    continue

                # now work harder to find the match
                sm.set_seq2(muls[j][1])
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]),
                                                 0, len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
            else:
                ncom = []

            com = list(ccom) + ncom
            if len(com) < 2:
                continue

            to_eliminate.add(Mul(*com))

            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                for k in xrange(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # make to_eliminate canonical; we will prefer non-Muls to Muls
    # so select them first (non-Muls will have False for is_Mul and will
    # be first in the ordering).
    to_eliminate = list(ordered(to_eliminate, lambda _: _.is_Mul))

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(reduced_exprs)
    # ``hit`` records whether the current symbol was actually used; a new
    # symbol is drawn only after the previous one produced a change.
    hit = True
    for i, subtree in enumerate(to_eliminate):
        if hit:
            sym = next(symbols)
        hit = False
        if subtree.is_Pow and subtree.exp.is_Rational:
            # replace the power and its reciprocal in one pass
            update = lambda x: x.xreplace({subtree: sym, 1/subtree: 1/sym})
        else:
            update = lambda x: x.subs(subtree, sym)
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            old = reduced_exprs[j]
            reduced_exprs[j] = update(expr)
            hit = hit or (old != reduced_exprs[j])
        # Make the substitution in all of the subsequent substitutions.
        for j in range(i + 1, len(to_eliminate)):
            old = to_eliminate[j]
            to_eliminate[j] = update(to_eliminate[j])
            hit = hit or (old != to_eliminate[j])
        if hit:
            replacements.append((sym, subtree))

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    # remove replacements that weren't used more than once
    _remove_singletons(replacements, reduced_exprs)

    # A Matrix input comes back as a single Matrix of the same shape,
    # wrapped in a one-element list.
    if isinstance(exprs, Matrix):
        reduced_exprs = [Matrix(exprs.rows, exprs.cols, reduced_exprs)]
    if postprocess is None:
        return replacements, reduced_exprs
    return postprocess(replacements, reduced_exprs)
def cgen_ncomp(ncomp=3, nporder=2, aggstat=False, debug=False):
    """Generates a C function for ncomp (int) number of components.

    The jth key component is always in the first position and the kth key
    component is always in the second.  The number of enrichment stages (NP)
    is calculated via a taylor series approximation.  The order of this
    approximation may be set with nporder.  Only values of 1 or 2 are allowed.
    The aggstat argument determines whether the status messages should be
    aggregated and printed at the end or output as the function executes.

    Returns a ``(ccode, repnames, stat)`` tuple: the generated C body, the
    set of replacement names reported by ``cse_to_c`` and the accumulated
    status value.

    Raises ValueError if nporder is neither 1 nor 2.
    """
    # Fail fast, before any of the expensive symbolic work below.
    if nporder not in (1, 2):
        raise ValueError("nporder must be 1 or 2")

    start_time = time.time()
    stat = _aggstatus('', "generating {0} component enrichment".format(ncomp),
                      aggstat)
    r = range(0, ncomp)
    j = 0
    k = 1

    # setup-symbols
    alpha = Symbol('alpha', positive=True, real=True)
    LpF = Symbol('LpF', positive=True, real=True)
    PpF = Symbol('PpF', positive=True, real=True)
    TpF = Symbol('TpF', positive=True, real=True)
    SWUpF = Symbol('SWUpF', positive=True, real=True)
    SWUpP = Symbol('SWUpP', positive=True, real=True)
    NP = Symbol('NP', positive=True, real=True)    # Enrichment Stages
    NT = Symbol('NT', positive=True, real=True)    # De-enrichment Stages
    NP0 = Symbol('NP0', positive=True, real=True)  # Enrichment Stages Initial Guess
    NT0 = Symbol('NT0', positive=True, real=True)  # De-enrichment Stages Initial Guess
    NP1 = Symbol('NP1', positive=True, real=True)  # Enrichment Stages Computed Value
    NT1 = Symbol('NT1', positive=True, real=True)  # De-enrichment Stages Computed Value
    Mstar = Symbol('Mstar', positive=True, real=True)
    MW = [Symbol('MW[{0}]'.format(i), positive=True, real=True) for i in r]
    beta = [alpha**(Mstar - MWi) for MWi in MW]

    # np_closed helper terms
    NP_b = Symbol('NP_b', real=True)
    NP_2a = Symbol('NP_2a', real=True)
    NP_sqrt_base = Symbol('NP_sqrt_base', real=True)

    xF = [Symbol('xF[{0}]'.format(i), positive=True, real=True) for i in r]
    xPi = [Symbol('xP[{0}]'.format(i), positive=True, real=True) for i in r]
    xTi = [Symbol('xT[{0}]'.format(i), positive=True, real=True) for i in r]
    xPj = Symbol('xPj', positive=True, real=True)
    xFj = xF[j]
    xTj = Symbol('xTj', positive=True, real=True)
    ppf = (xFj - xTj)/(xPj - xTj)
    tpf = (xFj - xPj)/(xTj - xPj)

    xP = [(((xF[i]/ppf)*(beta[i]**(NT+1) - 1))
           / (beta[i]**(NT+1) - beta[i]**(-NP)))
          for i in r]
    xT = [(((xF[i]/tpf)*(1 - beta[i]**(-NP)))
           / (beta[i]**(NT+1) - beta[i]**(-NP)))
          for i in r]
    rfeed = xFj / xF[k]
    rprod = xPj / xP[k]
    rtail = xTj / xT[k]

    # setup constraint equations
    numer = [ppf*xP[i]*log(rprod) + tpf*xT[i]*log(rtail) - xF[i]*log(rfeed)
             for i in r]
    denom = [log(beta[j]) * ((beta[i] - 1.0)/(beta[i] + 1.0)) for i in r]
    LoverF = sum([n/d for n, d in zip(numer, denom)])
    SWUoverF = -1.0 * sum(numer)
    SWUoverP = SWUoverF / ppf

    # These two constraints are not used further down; they are kept as the
    # starting point of the hand-derived nt_closed expression below.
    prod_constraint = ((xPj/xFj)*ppf - (beta[j]**(NT+1) - 1)
                       / (beta[j]**(NT+1) - beta[j]**(-NP)))
    tail_constraint = ((xTj/xFj)*(sum(xT)) - (1 - beta[j]**(-NP))
                       / (beta[j]**(NT+1) - beta[j]**(-NP)))
    #xp_constraint = 1.0 - sum(xP)
    #xf_constraint = 1.0 - sum(xF)
    #xt_constraint = 1.0 - sum(xT)

    # This is NT(NP,...) and is correct!
    #nt_closed = solve(prod_constraint, NT)[0]

    # However, this is NT(NP,...) rewritten (by hand) to minimize the number
    # of NP and M* instances in the expression.  Luckily this only depends on
    # the key component and remains general no matter the number of
    # components.
    nt_closed = (-MW[0]*log(alpha) + Mstar*log(alpha) + log(xTj)
                 + log((-1.0 + xPj/xF[0])/(xPj - xTj))
                 - log(alpha**(NP*(MW[0] - Mstar))*(xF[0]*xPj - xPj*xTj)
                       / (-xF[0]*xPj + xF[0]*xTj) + 1)
                 )/((MW[0] - Mstar)*log(alpha))

    # new expression for normalized flow rate
    # NOTE: not needed, solved below
    #loverf = LoverF.xreplace({NT: nt_closed})

    # Define the constraint equation with which to solve NP. This is chosen
    # such to minimize the number of ops in the derivatives (and thus
    # np_closed).  Other, more verbose possibilities are commented out.
    #np_constraint = (xP[j]/sum(xP) - xPj).xreplace({NT: nt_closed})
    #np_constraint = (xP[j]- sum(xP)*xPj).xreplace({NT: nt_closed})
    #np_constraint = (xT[j]/sum(xT) - xTj).xreplace({NT: nt_closed})
    np_constraint = (xT[j] - sum(xT)*xTj).xreplace({NT: nt_closed})

    # get closed form approximation of NP via symbolic derivatives
    stat = _aggstatus(stat, " order-{0} NP approximation".format(nporder),
                      aggstat)
    d0NP = np_constraint.xreplace({NP: NP0})
    d1NP = diff(np_constraint, NP, 1).xreplace({NP: NP0})
    if nporder == 1:
        # First-order Taylor expansion f(NP) ~ d0NP + d1NP*(NP - NP0) = 0,
        # i.e. the Newton step NP0 - f(NP0)/f'(NP0).
        np_closed = NP0 - d0NP / d1NP
        # The quadratic helper terms do not exist at first order.
        # NOTE(review): confirm the surrounding C template does not rely on
        # NP_b / NP_2a / NP_sqrt_base being emitted in this case.
        np_stages = []
    else:
        d2NP = diff(np_constraint, NP, 2).xreplace({NP: NP0})/2.0
        # taylor series polynomial coefficients, grouped by order
        # f(x) = ax**2 + bx + c
        a = d2NP
        b = d1NP - 2*NP0*d2NP
        c = d0NP - NP0*d1NP + NP0*NP0*d2NP
        # quadratic eq. (minus only)
        #np_closed = (-b - sqrt(b**2 - 4*a*c)) / (2*a)
        # However, we need to break up this expr as follows to prevent
        # a floating point arithmetic bug if b**2 - 4*a*c is very close
        # to zero but happens to be negative.  LAME!!!
        np_2a = 2*a
        np_sqrt_base = b**2 - 4*a*c
        np_closed = (-NP_b - sqrt(NP_sqrt_base)) / (NP_2a)
        np_stages = [Eq(NP_b, b), Eq(NP_2a, np_2a),
                     # fix for floating point sqrt() error
                     Eq(NP_sqrt_base, np_sqrt_base),
                     Eq(NP_sqrt_base, Abs(NP_sqrt_base))]

    # generate cse for writing out
    msg = " minimizing ops by eliminating common sub-expressions"
    stat = _aggstatus(stat, msg, aggstat)
    exprstages = np_stages + [Eq(NP1, np_closed),
                              Eq(NT1, nt_closed).xreplace({NP: NP1})]
    cse_stages = cse(exprstages, numbered_symbols('n'))
    exprothers = [Eq(LpF, LoverF), Eq(PpF, ppf), Eq(TpF, tpf),
                  Eq(SWUpF, SWUoverF), Eq(SWUpP, SWUoverP)] + \
                 [Eq(*z) for z in zip(xPi, xP)] + \
                 [Eq(*z) for z in zip(xTi, xT)]
    exprothers = [e.xreplace({NP: NP1, NT: NT1}) for e in exprothers]
    cse_others = cse(exprothers, numbered_symbols('g'))
    exprops = count_ops(exprstages + exprothers)
    cse_ops = count_ops(cse_stages + cse_others)
    msg = " reduced {0} ops to {1}".format(exprops, cse_ops)
    stat = _aggstatus(stat, msg, aggstat)

    # create function body
    ccode, repnames = cse_to_c(*cse_stages, indent=6, debug=debug)
    ccode_others, repnames_others = cse_to_c(*cse_others, indent=6,
                                             debug=debug)
    ccode += ccode_others
    repnames |= repnames_others

    msg = " completed in {0:.3G} s".format(time.time() - start_time)
    stat = _aggstatus(stat, msg, aggstat)
    if aggstat:
        print(stat)
    return ccode, repnames, stat
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
        order='canonical'):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        Source of the labels given to the subexpressions that are pulled
        out.  ``numbered_symbols`` is a convenient generator; when nothing
        is passed a stream "x0", "x1", ... is used.  The iterator must never
        run out.
    optimizations : list of (callable, callable) pairs
        (preprocessor, postprocessor) pairs of external optimization
        functions.  The string 'basic' selects a predefined set of basic
        optimizations.  Those were applied by default in the old
        implementation but can be really slow on larger expressions, so by
        default no pre- or post-optimization is done.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: return reversed(r), e
    order : string, 'none' or 'canonical'
        How Mul and Add arguments are processed.  'canonical' orders the
        arguments canonically.  'none' is faster but depends on expression
        hashes and is therefore machine dependent and variable; use it for
        large expressions where speed is a concern.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        The common subexpressions that were replaced.  Earlier entries may
        appear inside later ones.
    reduced_exprs : list of sympy expressions
        The input expressions rewritten in terms of the replacements.
    """
    from sympy.matrices import Matrix

    # Accept any iterable of symbols, not only a true iterator.
    symbols = numbered_symbols() if symbols is None else iter(symbols)

    if optimizations is None:
        optimizations = list()
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # A lone expression is treated as a one-element list.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess, look for extra optimization opportunities, then run the
    # main tree-based algorithm.
    prepared = [preprocess_for_cse(item, optimizations) for item in exprs]
    opt_subs = opt_cse(prepared, order)
    replacements, reduced = tree_cse(prepared, symbols, opt_subs, order)

    # Bring everything back to canonical form, updating the replacement
    # list in place.
    for position in range(len(replacements)):
        label, body = replacements[position]
        replacements[position] = (label,
                                  postprocess_for_cse(body, optimizations))
    reduced = [postprocess_for_cse(item, optimizations) for item in reduced]

    if isinstance(exprs, Matrix):
        reduced = [Matrix(exprs.rows, exprs.cols, reduced)]

    if postprocess is not None:
        return postprocess(replacements, reduced)
    return replacements, reduced
def _gauss_jordan_solve(M, B, freevar=False):
    """
    Solve ``Ax = B`` by Gauss-Jordan elimination, where ``A`` is ``M``.

    The system may have zero, one or infinitely many solutions.  A unique
    solution is returned as is; an infinite family is returned in
    parametric form; if there is no solution a ValueError is raised.

    Parameters
    ==========

    M : Matrix
        The coefficient matrix ``A``.

    B : Matrix
        The right hand side of the equation to be solved for.  Must have
        the same number of rows as matrix A.

    freevar : bool
        When true, the indices of the free variables (the non-pivot
        columns) within the solution are returned as a third value.

    Returns
    =======

    x : Matrix
        The matrix that will satisfy ``Ax = B``.  Will have as many rows as
        matrix A has columns, and as many columns as matrix B.

    params : Matrix
        If the system is underdetermined (e.g. A has more columns than
        rows), infinite solutions are possible, in terms of arbitrary
        parameters.  These arbitrary parameters are returned as params
        Matrix.

    Examples
    ========

    >>> from sympy import Matrix
    >>> A = Matrix([[1, 2, 1, 1], [1, 2, 2, -1], [2, 4, 0, 6]])
    >>> B = Matrix([7, 12, 4])
    >>> sol, params = A.gauss_jordan_solve(B)
    >>> sol
    Matrix([
    [-2*tau0 - 3*tau1 + 2],
    [                tau0],
    [          2*tau1 + 5],
    [                tau1]])
    >>> params
    Matrix([
    [tau0],
    [tau1]])
    >>> taus_zeroes = { tau:0 for tau in params }
    >>> sol_unique = sol.xreplace(taus_zeroes)
    >>> sol_unique
    Matrix([
    [2],
    [0],
    [5],
    [0]])

    >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
    >>> B = Matrix([3, 6, 9])
    >>> sol, params = A.gauss_jordan_solve(B)
    >>> sol
    Matrix([
    [-1],
    [ 2],
    [ 0]])
    >>> params
    Matrix(0, 1, [])

    >>> A = Matrix([[2, -7], [-1, 4]])
    >>> B = Matrix([[-21, 3], [12, -2]])
    >>> sol, params = A.gauss_jordan_solve(B)
    >>> sol
    Matrix([
    [0, -2],
    [3, -1]])
    >>> params
    Matrix(0, 2, [])

    See Also
    ========

    sympy.matrices.dense.DenseMatrix.lower_triangular_solve
    sympy.matrices.dense.DenseMatrix.upper_triangular_solve
    cholesky_solve
    diagonal_solve
    LDLsolve
    LUsolve
    QRsolve
    pinv

    References
    ==========

    .. [1] https://en.wikipedia.org/wiki/Gaussian_elimination

    """
    from sympy.matrices import Matrix, zeros

    matrix_type = M.__class__
    augmented = M.hstack(M.copy(), B.copy())
    n_rhs = B.cols
    # number of unknowns = columns of the coefficient part
    n_unknowns = augmented[:, :-n_rhs].shape[1]

    # Reduced row echelon form of [A | B], split back into its two parts.
    reduced, pivots = augmented.rref(simplify=True)
    coeffs = reduced[:, :-n_rhs]
    rhs = reduced[:, -n_rhs:]
    # Only pivots that fall inside the coefficient part count.
    pivots = [p for p in pivots if p < n_unknowns]
    rank = len(pivots)

    # Column permutation that moves the pivot columns to the front.
    permutation = Matrix(range(n_unknowns)).T
    for target, pivot_col in enumerate(pivots):
        permutation.col_swap(target, pivot_col)

    # A solution exists only if the rows below the rank are zero on the
    # right-hand side too (rank of [A | B] equals rank of A).
    if not rhs[rank:, :].is_zero_matrix:
        raise ValueError("Linear system has no solution")

    # The non-pivot columns are the free variables.
    free_var_index = permutation[len(pivots):]

    # Pick a parameter name that does not collide with symbols already in
    # the system (trailing digits are ignored in the comparison).
    name = _uniquely_named_symbol(
        'tau', augmented,
        compare=lambda i: str(i).rstrip('1234567890')).name
    gen = numbered_symbols(name)
    n_free = n_unknowns - rank
    tau = Matrix([next(gen) for _ in range(n_free * n_rhs)])
    tau = tau.reshape(n_free, n_rhs)

    # Full parametric solution: pivot unknowns first, then the parameters.
    free_cols = [c for c in range(coeffs.cols) if c not in pivots]
    V = coeffs[:rank, free_cols]
    vt = rhs[:rank, :]
    free_sol = tau.vstack(vt - V * tau, tau)

    # Undo the permutation.
    sol = zeros(n_unknowns, n_rhs)
    for k in range(n_unknowns):
        sol[permutation[k], :] = free_sol[k, :]

    sol, tau = matrix_type(sol), matrix_type(tau)

    if freevar:
        return sol, tau, free_var_index
    return sol, tau
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
        order='canonical', ignore=()):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.  Any iterable is accepted; it is
        consumed exactly once.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs of external optimization
        functions. Optionally 'basic' can be passed for a set of predefined
        basic optimizations. Such 'basic' optimizations were used by default
        in old implementation, however they can be really slow on larger
        expressions. Now, no pre or post optimizations are made by default.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: (reversed(r), e)
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. If set to
        'canonical', arguments will be canonically ordered. If set to 'none',
        ordering will be faster but dependent on expressions hashes, thus
        machine dependent and variable. For large expressions where speed is a
        concern, use the setting order='none'.
    ignore : iterable of Symbols
        Substitutions containing any Symbol from ``ignore`` will be ignored.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.

    Examples
    ========

    >>> from sympy import cse, SparseMatrix
    >>> from sympy.abc import x, y, z, w
    >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3)
    ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3])

    Note that currently, y + z will not get substituted if -y - z is used.

    >>> cse(((w + x + y + z)*(w - y - z))/(w + x)**3)
    ([(x0, w + x)], [(w - y - z)*(x0 + y + z)/x0**3])

    List of expressions with recursive substitutions:

    >>> m = SparseMatrix([x + y, x + y + z])
    >>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m])
    ([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([
    [x0],
    [x1]])])

    Note: the type and mutability of input matrices is retained.

    >>> isinstance(_[1][-1], SparseMatrix)
    True

    The user may disallow substitutions containing certain symbols:

    >>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,))
    ([(x0, x + 1)], [x0*y**2, 3*x0*y**2])

    """
    from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix,
                                SparseMatrix, ImmutableSparseMatrix)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, (Basic, MatrixBase)):
        exprs = [exprs]

    # The inputs are walked twice (flattening here, matrix restoration at
    # the end), so materialise them once; otherwise a generator would be
    # exhausted before the matrices could be rebuilt.
    originals = list(exprs)

    # Matrices are flattened into Tuples so the tree algorithm can see
    # their entries; sparse ones keep their (key, value) items.
    flattened = []
    for e in originals:
        if isinstance(e, (Matrix, ImmutableMatrix)):
            flattened.append(Tuple(*e._mat))
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            flattened.append(Tuple(*e._smat.items()))
        else:
            flattened.append(e)

    if optimizations is None:
        optimizations = list()
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in flattened]

    if symbols is None:
        symbols = numbered_symbols(cls=Symbol)
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order, ignore)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    # Get the matrices back, preserving density and mutability of the input.
    for i, e in enumerate(originals):
        if isinstance(e, (Matrix, ImmutableMatrix)):
            reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
            if isinstance(e, ImmutableMatrix):
                reduced_exprs[i] = reduced_exprs[i].as_immutable()
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            m = SparseMatrix(e.rows, e.cols, {})
            for k, v in reduced_exprs[i]:
                m[k] = v
            if isinstance(e, ImmutableSparseMatrix):
                m = m.as_immutable()
            reduced_exprs[i] = m

    if postprocess is None:
        return replacements, reduced_exprs

    return postprocess(replacements, reduced_exprs)
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The `numbered_symbols` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        `sympy.simplify.cse.cse_optimizations` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    muls = set()
    adds = set()
    # ``to_eliminate`` and ``to_eliminate_ops_count`` are parallel lists kept
    # sorted by operation count (see ``insert`` below).
    to_eliminate = []
    to_eliminate_ops_count = []

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.

    def insert(subtree):
        '''This helper will insert the subtree into to_eliminate while
        maintaining the ordering by op count and will skip the insertion
        if subtree is already present.'''
        ops_count = subtree.count_ops()
        index_to_insert = bisect.bisect(to_eliminate_ops_count, ops_count)
        # all i up to this index have op count <= the current op count
        # so check that subtree is not yet present from this index down
        # (if necessary) to zero.
        for i in range(index_to_insert - 1, -1, -1):
            if to_eliminate_ops_count[i] == ops_count and \
                    subtree == to_eliminate[i]:
                return  # already have it
        to_eliminate_ops_count.insert(index_to_insert, ops_count)
        to_eliminate.insert(index_to_insert, subtree)

    for expr in exprs:
        pt = preorder_traversal(expr)
        for subtree in pt:
            if subtree.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue

            if subtree in seen_subexp:
                insert(subtree)
                # Do not descend into a subtree that is already scheduled.
                pt.skip()
                continue

            if subtree.is_Mul:
                muls.add(subtree)
            elif subtree.is_Add:
                adds.add(subtree)

            seen_subexp.add(subtree)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in adds]
    for i in range(len(adds)):
        for j in range(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                insert(Add(*com))

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                for k in range(j + 1, len(adds)):
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    muls = [a.args_cnc() for a in muls]
    for i in range(len(muls)):
        if muls[i][1]:
            sm.set_seq1(muls[i][1])
        for j in range(i + 1, len(muls)):
            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                if len(ccom) + len(ncom) < 2:
                    continue

                # now work harder to find the match
                sm.set_seq2(muls[j][1])
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]),
                                                 0, len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
            else:
                ncom = []

            com = list(ccom) + ncom
            if len(com) < 2:
                continue

            insert(Mul(*com))

            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                for k in range(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        # The builtin next() works for iterators on both Python 2.6+ and 3.
        sym = next(symbols)
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        for j in range(i + 1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    return replacements, reduced_exprs
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
        order='canonical', ignore=(), light_ignore=()):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : iterable of expressions/matrices, or a single expression/matrix
        The expressions to reduce.  Dense matrices are flattened to a
        ``Tuple`` of their entries and sparse matrices to a ``Tuple`` of
        their ``(key, value)`` items for processing, and are rebuilt with
        their original type afterwards.
    symbols : iterable yielding unique Symbols, optional
        Labels for the extracted subexpressions.  Defaults to
        ``numbered_symbols(cls=Symbol)``.
    optimizations : list of (callable, callable) pairs, or 'basic', optional
        (preprocessor, postprocessor) pairs.  ``None`` means no
        optimizations; ``'basic'`` selects ``basic_optimizations``.
    postprocess : callable, optional
        Called as ``postprocess(replacements, reduced_exprs)``; its return
        value is returned instead of the plain pair.
    order : string
        Forwarded unchanged to ``opt_cse`` and ``tree_cse``.
    ignore : iterable
        Forwarded unchanged to ``tree_cse``.
    light_ignore : iterable
        Forwarded unchanged to ``tree_cse``.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
    reduced_exprs : list
        One entry per input expression, in input order.
    """
    # Handle the case if just one expression was passed.
    if isinstance(exprs, (Basic, MatrixBase)):
        exprs = [exprs]
    else:
        # The input is walked twice (flattening below, matrix rebuilding at
        # the end), so a one-shot iterable must be materialized up front.
        exprs = list(exprs)

    originals = exprs
    flattened = []
    for e in originals:
        if isinstance(e, (Matrix, ImmutableMatrix)):
            flattened.append(Tuple(*e._mat))
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            flattened.append(Tuple(*e._smat.items()))
        else:
            flattened.append(e)

    if optimizations is None:
        optimizations = list()
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in flattened]

    if symbols is None:
        symbols = numbered_symbols(cls=Symbol)
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order, ignore, light_ignore)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    # Get the matrices back
    for i, e in enumerate(originals):
        if isinstance(e, (Matrix, ImmutableMatrix)):
            reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
            if isinstance(e, ImmutableMatrix):
                reduced_exprs[i] = reduced_exprs[i].as_immutable()
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            m = SparseMatrix(e.rows, e.cols, {})
            for k, v in reduced_exprs[i]:
                m[k] = v
            if isinstance(e, ImmutableSparseMatrix):
                m = m.as_immutable()
            reduced_exprs[i] = m

    if postprocess is None:
        return replacements, reduced_exprs

    return postprocess(replacements, reduced_exprs)
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The `numbered_symbols` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        `sympy.simplify.cse.cse_optimizations` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    to_eliminate = []
    # Hash-based mirror of ``to_eliminate`` used only for the "already
    # scheduled" check, so that check does not scan the list every time.
    scheduled = set()

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    for expr in exprs:
        for subtree in postorder_traversal(expr):
            if subtree.args == ():
                # Exclude atoms, since there is no point in renaming them.
                continue
            if subtree in seen_subexp and subtree not in scheduled:
                scheduled.add(subtree)
                to_eliminate.append(subtree)
            seen_subexp.add(subtree)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        # The builtin next() works for iterators on both Python 2.6+ and 3.
        sym = next(symbols)
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        # Only entries after position i are rewritten, so the enclosing
        # enumerate() picks up the updated subtrees when it reaches them.
        for j in range(i + 1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    return replacements, reduced_exprs
def test_numbered_symbols():
    """numbered_symbols(cls=Dummy) must yield Dummy instances."""
    s = numbered_symbols(cls=Dummy)
    # Use the builtin next(): generators have no .next() method on Python 3.
    assert isinstance(next(s), Dummy)
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
        order='canonical'):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs of external optimization
        functions. Optionally 'basic' can be passed for a set of predefined
        basic optimizations. Such 'basic' optimizations were used by default
        in old implementation, however they can be really slow on larger
        expressions. Now, no pre or post optimizations are made by default.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: (reversed(r), e)
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. If set to
        'canonical', arguments will be canonically ordered. If set to 'none',
        ordering will be faster but dependent on expressions hashes, thus
        machine dependent and variable. For large expressions where speed is a
        concern, use the setting order='none'.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    from sympy.matrices import Matrix

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    if optimizations is None:
        optimizations = list()
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Symbols already present in the input must not be reused as labels.
    # ``set().union(...)`` (rather than the unbound ``set.union``) also works
    # when there are no expressions at all.
    excluded_symbols = set().union(*[expr.atoms(Symbol)
                                     for expr in reduced_exprs])

    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)

    symbols = filter_symbols(symbols, excluded_symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    # A Matrix input was processed entry by entry; rebuild it with its shape.
    if isinstance(exprs, Matrix):
        reduced_exprs = [Matrix(exprs.rows, exprs.cols, reduced_exprs)]

    if postprocess is None:
        return replacements, reduced_exprs

    return postprocess(replacements, reduced_exprs)
def rsolve_hyper(coeffs, f, n, **hints):
    r"""
    Given linear recurrence operator `\operatorname{L}` of order `k`
    with polynomial coefficients and inhomogeneous equation
    `\operatorname{L} y = f` we seek for all hypergeometric solutions
    over field `K` of characteristic zero.

    The inhomogeneous part can be either hypergeometric or a sum
    of a fixed number of pairwise dissimilar hypergeometric terms.

    The algorithm performs three basic steps:

        (1) Group together similar hypergeometric terms in the
            inhomogeneous part of `\operatorname{L} y = f`, and find
            particular solution using Abramov's algorithm.

        (2) Compute generating set of `\operatorname{L}` and find basis
            in it, so that all solutions are linearly independent.

        (3) Form final solution with the number of arbitrary
            constants equal to dimension of basis of `\operatorname{L}`.

    Term `a(n)` is hypergeometric if it is annihilated by first order
    linear difference equations with polynomial coefficients or, in
    simpler words, if consecutive term ratio is a rational function.

    The output of this procedure is a linear combination of fixed
    number of hypergeometric terms. However the underlying method
    can generate larger class of solutions - D'Alembertian terms.

    Note also that this method not only computes the kernel of the
    inhomogeneous equation, but also reduces it to a basis so that
    solutions generated by this procedure are linearly independent.

    If ``symbols=True`` is passed as a hint, a ``(solution, symbols)`` pair
    is returned instead of the bare solution.  ``None`` is returned when no
    solution is found.

    Examples
    ========

    >>> from sympy.solvers import rsolve_hyper
    >>> from sympy.abc import x

    >>> rsolve_hyper([-1, -1, 1], 0, x)
    C0*(1/2 + sqrt(5)/2)**x + C1*(-sqrt(5)/2 + 1/2)**x

    >>> rsolve_hyper([-1, 1], 1 + x, x)
    C0 + x*(x + 1)/2

    References
    ==========

    .. [1] M. Petkovsek, Hypergeometric solutions of linear recurrences
           with polynomial coefficients, J. Symbolic Computation,
           14 (1992), 243-264.

    .. [2] M. Petkovsek, H. S. Wilf, D. Zeilberger, A = B, 1996.
    """
    # A real list is required: it is measured, indexed and sliced below.
    coeffs = list(map(sympify, coeffs))

    f = sympify(f)

    r, kernel, symbols = len(coeffs) - 1, [], set()

    if not f.is_zero:
        if f.is_Add:
            # Maps a representative term to the sum of the *other* terms
            # that are hypersimilar to it.
            similar = {}

            for g in f.expand().args:
                if not g.is_hypergeometric(n):
                    return None

                for h in similar:
                    if hypersimilar(g, h, n):
                        # Only the value of an existing key changes, so
                        # the dict is safe to update while iterating.
                        similar[h] += g
                        break
                else:
                    similar[g] = S.Zero

            inhomogeneous = []

            for g, h in similar.items():
                inhomogeneous.append(g + h)
        elif f.is_hypergeometric(n):
            inhomogeneous = [f]
        else:
            return None

        for i, g in enumerate(inhomogeneous):
            coeff, polys = S.One, coeffs[:]
            denoms = [S.One] * (r + 1)

            s = hypersimp(g, n)

            for j in range(1, r + 1):
                coeff *= s.subs(n, n + j - 1)

                p, q = coeff.as_numer_denom()

                polys[j] *= p
                denoms[j] = q

            for j in range(0, r + 1):
                polys[j] *= Mul(*(denoms[:j] + denoms[j + 1:]))

            R = rsolve_poly(polys, Mul(*denoms), n)

            if not (R is None or R is S.Zero):
                inhomogeneous[i] *= R
            else:
                return None

            result = Add(*inhomogeneous)
    else:
        result = S.Zero

    Z = Dummy('Z')

    p, q = coeffs[0], coeffs[r].subs(n, n - r + 1)

    p_factors = list(roots(p, n).keys())
    q_factors = list(roots(q, n).keys())

    factors = [(S.One, S.One)]

    for p in p_factors:
        for q in q_factors:
            if p.is_integer and q.is_integer and p <= q:
                continue
            else:
                factors += [(n - p, n - q)]

    p = [(n - p, S.One) for p in p_factors]
    q = [(S.One, n - q) for q in q_factors]

    factors = p + factors + q

    for A, B in factors:
        polys, degrees = [], []
        D = A*B.subs(n, n + r - 1)

        for i in range(0, r + 1):
            a = Mul(*[A.subs(n, n + j) for j in range(0, i)])
            b = Mul(*[B.subs(n, n + j) for j in range(i, r)])

            poly = quo(coeffs[i]*a*b, D, n)
            polys.append(poly.as_poly(n))

            if not poly.is_zero:
                degrees.append(polys[i].degree())

        d, poly = max(degrees), S.Zero

        for i in range(0, r + 1):
            coeff = polys[i].nth(d)

            if coeff is not S.Zero:
                poly += coeff * Z**i

        for z in roots(poly, Z).keys():
            if z.is_zero:
                continue

            (C, s) = rsolve_poly([polys[i]*z**i for i in range(r + 1)],
                                 0, n, symbols=True)

            if C is not None and C is not S.Zero:
                symbols |= set(s)

                ratio = z * A * C.subs(n, n + 1) / B / C
                ratio = simplify(ratio)
                # If there is a nonnegative root in the denominator of the ratio,
                # this indicates that the term y(n_root) is zero, and one should
                # start the product with the term y(n_root + 1).
                n0 = 0
                for n_root in roots(ratio.as_numer_denom()[1], n).keys():
                    # ``== True`` rather than ``is True``: the comparison
                    # may yield a SymPy boolean instead of the Python
                    # ``True`` singleton; an undecided relation still
                    # compares unequal to True.
                    if (n0 < (n_root + 1)) == True:
                        n0 = n_root + 1
                K = product(ratio, (n, n0, n - 1))
                if K.has(factorial, FallingFactorial, RisingFactorial):
                    K = simplify(K)

                if casoratian(kernel + [K], n, zero=False) != 0:
                    kernel.append(K)

    kernel.sort(key=default_sort_key)
    # Materialize the pairs: they are truth-tested and then walked twice.
    sk = list(zip(numbered_symbols('C'), kernel))

    if sk:
        for C, ker in sk:
            result += C * ker
    else:
        return None

    if hints.get('symbols', False):
        symbols |= set([s for s, k in sk])
        return (result, list(symbols))
    else:
        return result
def test_filter_symbols():
    """filter_symbols must drop the excluded names from the numbered stream."""
    excluded = symbols("x0 x2 x3")
    expected = list(symbols("x1 x4 x5"))
    stream = filter_symbols(numbered_symbols(), excluded)
    first_three = take(stream, 3)
    assert first_three == expected
def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters:

    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided,
        ``sympy.simplify.cse.cse_optimizations`` is used.

    Returns:

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    muls = set()
    adds = set()
    # ``to_eliminate`` and ``to_eliminate_ops_count`` are parallel lists kept
    # sorted by operation count (see ``insert`` below).
    to_eliminate = []
    to_eliminate_ops_count = []

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]

    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.

    def insert(subtree):
        '''This helper will insert the subtree into to_eliminate while
        maintaining the ordering by op count and will skip the insertion
        if subtree is already present.'''
        ops_count = subtree.count_ops()
        index_to_insert = bisect.bisect(to_eliminate_ops_count, ops_count)
        # all i up to this index have op count <= the current op count
        # so check that subtree is not yet present from this index down
        # (if necessary) to zero.
        for i in range(index_to_insert - 1, -1, -1):
            if to_eliminate_ops_count[i] == ops_count and \
                    subtree == to_eliminate[i]:
                return  # already have it
        to_eliminate_ops_count.insert(index_to_insert, ops_count)
        to_eliminate.insert(index_to_insert, subtree)

    for expr in exprs:
        pt = preorder_traversal(expr)
        for subtree in pt:
            if subtree.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue

            if subtree in seen_subexp:
                insert(subtree)
                # Do not descend into a subtree that is already scheduled.
                pt.skip()
                continue

            if subtree.is_Mul:
                muls.add(subtree)
            elif subtree.is_Add:
                adds.add(subtree)

            seen_subexp.add(subtree)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in adds]
    for i in range(len(adds)):
        for j in range(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                insert(Add(*com))

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                for k in range(j + 1, len(adds)):
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    muls = [a.args_cnc() for a in muls]
    for i in range(len(muls)):
        if muls[i][1]:
            sm.set_seq1(muls[i][1])
        for j in range(i + 1, len(muls)):
            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                if len(ccom) + len(ncom) < 2:
                    continue

                # now work harder to find the match
                sm.set_seq2(muls[j][1])
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]),
                                                 0, len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
            else:
                ncom = []

            com = list(ccom) + ncom
            if len(com) < 2:
                continue

            insert(Mul(*com))

            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                for k in range(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        # The builtin next() works for iterators on both Python 2.6+ and 3.
        sym = next(symbols)
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        for j in range(i + 1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    return replacements, reduced_exprs
def _eval_rewrite_as_sqrt(self, arg):
    """Try to rewrite the cosine of ``arg`` using square roots.

    Returns an expression built from ``sqrt`` when the coefficient of pi
    extracted by ``_pi_coeff(arg)`` is a non-integer Rational whose
    denominator can be handled by the table / halving / decomposition
    steps below; returns ``None`` in every other case.
    """
    # When False, only denominators covered by ``cst_table_some`` (and
    # powers of two times those) are expanded.
    _EXPAND_INTS = False

    def migcdex(x):
        # recursive calculation of gcd and linear combination
        # for a sequence of integers.
        # Given  (x1, x2, x3)
        # Returns (y1, y2, y3, g)
        # such that g is the gcd and x1*y1+x2*y2+x3*y3 - g = 0
        # Note, that this is only one such linear combination.
        if len(x) == 1:
            return (1, x[0])
        if len(x) == 2:
            return igcdex(x[0], x[-1])
        g = migcdex(x[1:])
        u, v, h = igcdex(x[0], g[-1])
        return tuple([u] + [v*i for i in g[0:-1]] + [h])

    def ipartfrac(r, factors=None):
        # Split the Rational ``r`` into a list of Rationals summing to
        # ``r``, one per entry of ``factors`` (or per prime power of r.q
        # when ``factors`` is not given).
        if isinstance(r, int):
            return r
        assert isinstance(r, C.Rational)
        n = r.q
        # NOTE(review): this returns the denominator itself rather than a
        # list; kept as in the original -- confirm intent.
        if 2 > r.q*r.q:
            return r.q

        # Floor division keeps these cofactors integral on Python 3; the
        # divisions are exact because every divisor divides r.q.
        if factors is None:
            a = [n//x**y for x, y in factorint(r.q).items()]
        else:
            a = [n//x for x in factors]
        if len(a) == 1:
            return [r]
        h = migcdex(a)
        ans = [r.p*C.Rational(i*j, r.q) for i, j in zip(h[:-1], a)]
        assert r == sum(ans)
        return ans

    pi_coeff = _pi_coeff(arg)
    if pi_coeff is None:
        return None

    assert not pi_coeff.is_integer, "should have been simplified already"

    if not pi_coeff.is_Rational:
        return None

    # Closed forms used as the base case, keyed by denominator.
    cst_table_some = {
        3: S.Half,
        5: (sqrt(5) + 1)/4,
        17: sqrt((15 + sqrt(17))/32 + sqrt(2)*(sqrt(17 - sqrt(17)) +
            sqrt(sqrt(2)*(-8*sqrt(17 + sqrt(17)) - (1 - sqrt(17))
            *sqrt(17 - sqrt(17))) + 6*sqrt(17) + 34))/32)
        # 65537 and 257 are the only other known Fermat primes
        # Please add if you would like them
    }

    def fermatCoords(n):
        # Return the tuple of table primes whose product is ``n`` when
        # ``n`` is an odd product of distinct table primes, else False.
        assert isinstance(n, int)
        assert n > 0
        if n == 1 or 0 == n % 2:
            return False
        primes = dict.fromkeys(cst_table_some, 0)
        assert 1 not in primes
        for p_i in primes:
            while 0 == n % p_i:
                # Exact integer division (stays an int on Python 3).
                n = n//p_i
                primes[p_i] += 1
        if 1 != n:
            return False
        if max(primes.values()) > 1:
            return False
        return tuple([p for p in primes if primes[p] == 1])

    if pi_coeff.q in cst_table_some:
        return C.chebyshevt(pi_coeff.p, cst_table_some[pi_coeff.q]).expand()

    if 0 == pi_coeff.q % 2:  # recursively remove powers of 2
        narg = (pi_coeff*2)*S.Pi
        nval = cos(narg)
        if nval is None:
            return None
        nval = nval.rewrite(sqrt)
        if not _EXPAND_INTS:
            # The doubled angle could not be rewritten, so give up.
            if (isinstance(nval, cos) or isinstance(-nval, cos)):
                return None
        x = (2*pi_coeff + 1)/2
        sign_cos = (-1)**((-1 if x < 0 else 1)*int(abs(x)))
        return sign_cos*sqrt((1 + nval)/2)

    FC = fermatCoords(pi_coeff.q)
    if FC:
        decomp = ipartfrac(pi_coeff, FC)
        X = [(x[1], x[0]*S.Pi) for x in zip(decomp, numbered_symbols('z'))]
        pcls = cos(sum([x[0] for x in X]))._eval_expand_trig().subs(X)
        return pcls.rewrite(sqrt)
    if _EXPAND_INTS:
        decomp = ipartfrac(pi_coeff)
        X = [(x[1], x[0]*S.Pi) for x in zip(decomp, numbered_symbols('z'))]
        pcls = cos(sum([x[0] for x in X]))._eval_expand_trig().subs(X)
        return pcls
    return None
def __init__(self, commands):
    """Store ``commands`` and create a stream of "tmp"-prefixed symbols."""
    tmp_symbols = numbered_symbols("tmp")
    self.get_symbol = tmp_symbols
    self.commands = commands