def sub_for_func(expr, func_name, func_vars, func_expr):
    """
    Expand every call to func_name inside expr and return the result as a
    string.

    expr: The expression string in which to perform the substitution.
    func_name: The name of the function. Must parse to a simple name.
    func_vars: A sequence of the variable names used by the function
               expression, or the string '*' (see below).
    func_expr: The expression for the function.

    For example, if f(x, y, z) = sqrt(z)*x*y-z then
        func_name = 'f'
        func_vars = ['x', 'y', 'z']
        func_expr = 'sqrt(z)*x*y-z'

    As a special case, functions that take a variable number of arguments
    can use '*' for func_vars. For example:
        sub_for_func('or_func(or_func(A,D),B,C)', 'or_func', '*', 'x or y')
    yields '(A or D) or B or C'.

    Raises ValueError if func_name or any entry of func_vars is not a
    simple name, or if func_vars is '*' and the parsed func_expr has no
    'nodes' attribute.
    """
    expr_ast = strip_parse(expr)

    name_ast = strip_parse(func_name)
    if not isinstance(name_ast, Name):
        raise ValueError('Function name is not a simple name.')
    func_name = name_ast.name

    body_ast = strip_parse(func_expr)

    # '*' is not itself a parseable expression, so it bypasses strip_parse
    # and is passed through to the AST-level substitution unchanged.
    if func_vars == '*':
        if not hasattr(body_ast, 'nodes'):
            raise ValueError("Top-level function in %s does not appear to "
                             "accept variable number of arguments. (It has no "
                             "'nodes' attribute.)" % func_expr)
        var_names = '*'
    else:
        # Parse every variable before validating any of them, so a parse
        # failure is reported ahead of a "not a simple name" failure.
        parsed_vars = [strip_parse(var) for var in func_vars]
        if any(not isinstance(node, Name) for node in parsed_vars):
            raise ValueError('Function variable is not a simple name.')
        var_names = [node.name for node in parsed_vars]

    substituted = _sub_for_func_ast(expr_ast, func_name, var_names, body_ast)
    return ast2str(Simplify._simplify_ast(substituted))
def sub_for_comps(expr, mapping):
    """
    Return expr with comparison sub-expressions replaced per mapping.

    For each pair out_expr:in_expr in mapping, the returned string has all
    occurrences of the comparison out_expr substituted by in_expr.

    expr: The expression string in which to perform the substitution.
    mapping: A dictionary whose keys are comparison expressions (as
             strings) and whose values are the expressions (as strings) to
             substitute for them.

    If mapping is empty, expr is returned unchanged (it is not parsed or
    re-rendered).

    Raises ValueError if any key of mapping does not parse to a comparison.
    """
    if not mapping:
        return expr

    ast = strip_parse(expr)

    # Key the substitutions by the rendered form of each parsed comparison,
    # which is what is handed to _sub_subtrees_for_comps.
    ast_mapping = {}
    for out_expr, in_expr in mapping.items():
        out_ast = strip_parse(out_expr)
        if not isinstance(out_ast, Compare):
            # Call-style raise: valid on both Python 2 and Python 3, and
            # consistent with the other ValueErrors raised in this file.
            raise ValueError('Expression %s to substitute for is not a '
                             'comparison.' % (out_expr,))
        ast_mapping[ast2str(out_ast)] = strip_parse(in_expr)

    ast = _sub_subtrees_for_comps(ast, ast_mapping)
    return ast2str(ast)
def make_c_compatible(expr):
    """
    Translate a python math expression string into a C-compatible one.

    Per the transformation's contract:
      - python-style x**n exponents become pow(x, n);
      - integer constants become float values to avoid integer casting
        problems (e.g. '1' -> '1.0');
      - 'and', 'or', and 'not' become C's '&&', '||', and '!'.

    This may be fragile if the parsing library changes in newer python
    versions.
    """
    # Parse, rewrite at the AST level, then render back to a string.
    return ast2str(_make_c_compatible_ast(strip_parse(expr)))
('log', 1): ('1/arg0', ), ('log10', 1): ('1/(log(10)*arg0)', ), ('sin', 1): ('cos(arg0)', ), ('sinh', 1): ('cosh(arg0)', ), ('arcsinh', 1): ('1/sqrt(1+arg0**2)', ), ('arccosh', 1): ('1/sqrt(arg0**2 - 1.)', ), ('arctanh', 1): ('1/(1.-arg0**2)', ), ('sqrt', 1): ('1/(2*sqrt(arg0))', ), ('tan', 1): ('1/cos(arg0)**2', ), ('tanh', 1): ('1/cosh(arg0)**2', ), ('pow', 2): ('arg1 * arg0**(arg1-1)', 'log(arg0) * arg0**arg1'), ('min', 2): ('arg0<=arg1', 'arg0>arg1'), ('max', 2): ('arg0>=arg1', 'arg0<arg1') } for key, terms in _KNOWN_FUNCS.items(): _KNOWN_FUNCS[key] = [strip_parse(term) for term in terms] def _diff_ast(ast, wrt): """ Return an AST that is the derivative of ast with respect the variable with name 'wrt'. """ # For now, the strategy is to return the most general forms, and let # the simplifier take care of the special cases. if isinstance(ast, Name): if ast.name == wrt: return _ONE else: return _ZERO