示例#1
0
 def visit(expr):
     """Recursively rebuild *expr*, swapping constant terms/factors for symbols.

     Only the direct arguments of ``sp.Add`` and ``sp.Mul`` nodes are
     candidates for replacement. A candidate is replaced when ``is_constant``
     accepts it and its absolute value is not 1:

     - a negative constant ``c`` becomes ``-constants_to_subexp_dict[-c]``;
     - any other such constant ``c`` becomes ``constants_to_subexp_dict[c]``.

     Nodes without ``args`` (leaves) are returned unchanged. All other nodes
     are rebuilt via ``expr.func`` after visiting every (possibly replaced)
     argument.

     ``is_constant`` and ``constants_to_subexp_dict`` come from the enclosing
     scope, which is not visible here. Presumably the dict is keyed by
     positive constants, since negatives are looked up via ``-c``
     (TODO confirm).
     """
     def substitute(term):
         # Coefficients of +1/-1 stay as they are; every other constant is
         # looked up in the table (by magnitude when negative).
         if not (is_constant(term) and abs(term) != 1):
             return term
         if term < 0:
             return -constants_to_subexp_dict[-term]
         return constants_to_subexp_dict[term]

     children = list(expr.args)
     if not children:
         return expr
     if isinstance(expr, (sp.Add, sp.Mul)):
         children = [substitute(child) for child in children]
     return expr.func(*map(visit, children))
def to_placeholder_function(expr, name):
    """Wrap an expression in a freshly created sympy function named *name*.

    Replacing an expression with a plain symbol causes problems when
    computing derivatives, because the dependency on the free symbols is
    lost. A placeholder function avoids this.

    The result applies a new ``sp.Function`` subclass (also mixing in
    ``PlaceholderFunction``) to every free symbol of *expr*, sorted by name.
    Differentiating it with respect to an argument yields:

    - the derivative itself, when ``is_constant`` accepts it;
    - otherwise a symbol named ``_d<name>_d<argument name>``.

    The generated class carries these attributes:

    - ``value``: ``sp.Symbol(name)``;
    - ``subexpressions``: an ``Assignment`` defining ``name`` as *expr*,
      followed by one ``Assignment`` per non-constant derivative symbol;
    - ``nargs``: the number of free symbols.

    Examples:
        >>> x, t = sp.symbols("x, t")
        >>> temperature = x**2 + t**4 # some 'complicated' dependency
        >>> temperature_placeholder = to_placeholder_function(temperature, 'T')
        >>> diffusivity = temperature_placeholder + 42 * t
        >>> sp.diff(diffusivity, t)  # returns a symbol instead of the computed derivative
        _dT_dt + 42
        >>> result, subexpr = remove_placeholder_functions(diffusivity)
        >>> result
        T + 42*t
        >>> subexpr
        [Assignment(T, t**4 + x**2), Assignment(_dT_dt, 4*t**3), Assignment(_dT_dx, 2*x)]

    """
    free_symbols = sorted(expr.atoms(sp.Symbol), key=lambda sym: sym.name)
    placeholder_symbol = sp.Symbol(name)

    derivative_symbols = [
        sp.Symbol("_d{}_d{}".format(name, sym.name)) for sym in free_symbols
    ]
    derivatives = [sp.diff(expr, sym) for sym in free_symbols]

    # The placeholder's own definition comes first. Constant derivatives get
    # no assignment because fdiff returns them inline.
    subexpressions = [Assignment(placeholder_symbol, expr)]
    for deriv_symbol, deriv in zip(derivative_symbols, derivatives):
        if not is_constant(deriv):
            subexpressions.append(Assignment(deriv_symbol, deriv))

    def fdiff(_, index):
        # ``index`` is 1-based (hence ``index - 1`` below).
        position = index - 1
        deriv = derivatives[position]
        if is_constant(deriv):
            return deriv
        return derivative_symbols[position]

    class_namespace = {
        'fdiff': fdiff,
        'value': placeholder_symbol,
        'subexpressions': subexpressions,
        'nargs': len(free_symbols),
    }
    placeholder_class = type(name, (sp.Function, PlaceholderFunction), class_namespace)
    return placeholder_class(*free_symbols)
示例#3
0
 def fdiff(_, index):
     """Return the derivative of the placeholder w.r.t. argument *index*.

     *index* is 1-based, as shown by the ``index - 1`` lookup. When
     ``is_constant`` accepts the stored derivative it is returned directly;
     otherwise the matching stand-in symbol from ``derivative_symbols`` is
     returned. ``derivatives``, ``derivative_symbols`` and ``is_constant``
     come from the enclosing scope.
     """
     position = index - 1
     derivative = derivatives[position]
     if not is_constant(derivative):
         return derivative_symbols[position]
     return derivative
def insert_constants(ac, **kwargs):
    """Insert the subexpressions whose right-hand side is constant.

    "Constant" means ``is_constant(assignment.rhs)`` holds (i.e. the
    right-hand side contains no symbols). This is a thin wrapper around
    ``insert_subexpressions`` that supplies that selection predicate. All
    keyword arguments are forwarded unchanged.
    """
    def rhs_is_constant(assignment):
        return is_constant(assignment.rhs)

    return insert_subexpressions(ac, rhs_is_constant, **kwargs)