def visit(expr):
    """Rebuild *expr* bottom-up, swapping numeric operands of sums/products for subexpression symbols.

    For every ``sp.Add`` / ``sp.Mul`` node, each direct argument accepted by
    ``is_constant`` whose absolute value is not 1 is replaced by its entry in
    ``constants_to_subexp_dict``. Negative constants are looked up by their
    negation and the looked-up value is negated. Afterwards every argument is
    visited recursively and the node is rebuilt with ``expr.func``. Nodes
    without arguments are returned unchanged.

    NOTE(review): ``constants_to_subexp_dict`` is resolved from the surrounding
    scope; presumably it holds every constant that can occur here. A missing
    key fails with whatever the mapping raises on lookup.
    """
    children = list(expr.args)
    if not children:
        # Leaf node: nothing to rebuild.
        return expr

    if isinstance(expr, (sp.Add, sp.Mul)):
        def substitute(term):
            # Keep non-constants and unit constants (+1 / -1) as they are.
            if not (is_constant(term) and abs(term) != 1):
                return term
            # Negative constants: look up the magnitude, then restore the sign.
            if term < 0:
                return -constants_to_subexp_dict[-term]
            return constants_to_subexp_dict[term]

        children = [substitute(term) for term in children]

    return expr.func(*[visit(child) for child in children])
def to_placeholder_function(expr, name):
    """Wrap an expression into an instance of a freshly created sympy function.

    - substituting a plain symbol for an expression loses its dependency on the
      free variables, which gives wrong results when derivatives are computed
    - a placeholder function keeps that dependency: its partial derivatives are
      either the constant derivative itself or a symbol ``_d<name>_d<var>``
      standing for the non-constant derivative

    Args:
        expr: sympy expression to wrap.
        name: name of the created function class, of the value symbol and the
              stem of the derivative symbols.

    Returns:
        An instance of a new class deriving from ``sp.Function`` and
        ``PlaceholderFunction``, applied to the ``sp.Symbol`` atoms of *expr*
        sorted by name. The class carries ``value`` (``sp.Symbol(name)``),
        ``subexpressions`` (an ``Assignment`` of *expr* to the value symbol,
        followed by one ``Assignment`` per non-constant partial derivative)
        and ``nargs`` (number of arguments).

    Examples:
        >>> x, t = sp.symbols("x, t")
        >>> temperature = x**2 + t**4 # some 'complicated' dependency
        >>> temperature_placeholder = to_placeholder_function(temperature, 'T')
        >>> diffusivity = temperature_placeholder + 42 * t
        >>> sp.diff(diffusivity, t)  # returns a symbol instead of the computed derivative
        _dT_dt + 42
        >>> result, subexpr = remove_placeholder_functions(diffusivity)
        >>> result
        T + 42*t
        >>> subexpr
        [Assignment(T, t**4 + x**2), Assignment(_dT_dt, 4*t**3), Assignment(_dT_dx, 2*x)]
    """
    variables = sorted(expr.atoms(sp.Symbol), key=lambda sym: sym.name)

    derivative_symbols = [
        sp.Symbol("_d{}_d{}".format(name, var.name)) for var in variables
    ]
    derivatives = [sp.diff(expr, var) for var in variables]

    assignments = [Assignment(sp.Symbol(name), expr)]
    for symbol, derivative in zip(derivative_symbols, derivatives):
        # Constant derivatives are returned directly by fdiff, so they need no
        # subexpression of their own.
        if not is_constant(derivative):
            assignments.append(Assignment(symbol, derivative))

    def fdiff(_, index):
        # sympy requests the derivative w.r.t. the 1-based argument `index`.
        derivative = derivatives[index - 1]
        if is_constant(derivative):
            return derivative
        return derivative_symbols[index - 1]

    class_namespace = {
        'fdiff': fdiff,
        'value': sp.Symbol(name),
        'subexpressions': assignments,
        'nargs': len(variables),
    }
    placeholder_class = type(name, (sp.Function, PlaceholderFunction), class_namespace)
    return placeholder_class(*variables)
def fdiff(_, index):
    """Partial-derivative hook in the style of sympy's ``Function.fdiff``.

    Returns the derivative belonging to the 1-based argument *index* when
    ``is_constant`` accepts it, otherwise the symbol that stands for it.

    NOTE(review): ``derivatives`` and ``derivative_symbols`` are free names here;
    inside ``to_placeholder_function`` the same hook closes over that function's
    locals. Confirm this module-level copy is intentional and that both names
    are available wherever it is used.
    """
    position = index - 1
    derivative = derivatives[position]
    if is_constant(derivative):
        return derivative
    return derivative_symbols[position]
def insert_constants(ac, **kwargs):
    """Insert the subexpressions of *ac* whose right-hand side is constant.

    A right-hand side counts as constant when ``is_constant`` accepts it, i.e.
    it contains no symbols. All keyword arguments are forwarded unchanged to
    ``insert_subexpressions``, whose result is returned.
    """
    def rhs_is_constant(assignment):
        return is_constant(assignment.rhs)

    return insert_subexpressions(ac, rhs_is_constant, **kwargs)