def _sub_for_func_ast(ast, func_name, func_vars, func_expr_ast):
    """
    Return an ast with calls to the function func_name substituted out.

    ast: The tree in which to look for calls to func_name.
    func_name: The name of the function to substitute, as a string.
    func_vars: Either a sequence of the argument names used in
               func_expr_ast, or the string '*' for a function that takes
               a variable number of arguments.
    func_expr_ast: The ast of the expression defining the function. A deep
                   copy is made for every substitution, so the ast passed
                   in is not altered by this function.

    A call to func_name whose number of arguments does not match
    len(func_vars) is not substituted; it is handed to
    AST.recurse_down_tree like any other node.
    """
    is_target_call = (isinstance(ast, Call)
                      and ast2str(ast.func) == func_name)
    variadic = (func_vars == '*')

    if is_target_call and (variadic or len(ast.args) == len(func_vars)):
        # The arguments may themselves contain calls to func_name, so they
        # are substituted first.
        subbed_args = [_sub_for_func_ast(arg_ast, func_name, func_vars,
                                         func_expr_ast)
                       for arg_ast in ast.args]
        working_ast = copy.deepcopy(func_expr_ast)

        if variadic:
            # This subs out the arguments of the original function.
            working_ast.values = subbed_args
            return working_ast

        mapping = dict(zip(func_vars, subbed_args))
        # The result must be returned rather than discarded: when the
        # function expression is a single variable (e.g. f(x) = x), the
        # replacement is only available as the return value.
        return _sub_subtrees_for_vars(working_ast, mapping)

    return AST.recurse_down_tree(ast, _sub_for_func_ast,
                                 (func_name, func_vars, func_expr_ast,))
def _sub_subtrees_for_vars(ast, ast_mappings):
    """
    Substitute variables in ast according to ast_mappings.

    ast_mappings maps a variable name (out_name) to the ast (in_ast) that
    takes the place of every occurrence of that variable. If ast itself is
    such a variable, the mapped ast is returned; otherwise the search
    continues through AST.recurse_down_tree.
    """
    if isinstance(ast, Name):
        var_name = ast2str(ast)
        if var_name in ast_mappings:
            return ast_mappings[var_name]
    return AST.recurse_down_tree(ast, _sub_subtrees_for_vars,
                                 (ast_mappings,))
def sub_for_comps(expr, mapping):
    """
    Substitute comparisons in the expression string expr.

    mapping holds out_expr:in_expr pairs of strings. Every occurrence of
    the comparison out_expr in expr is replaced by in_expr, and the
    resulting expression is returned as a string. An empty mapping returns
    expr unchanged.

    Raises ValueError if one of the out_expr keys does not parse to a
    comparison.
    """
    if not len(mapping):
        return expr

    tree = strip_parse(expr)
    replacements = {}
    for target_text, replacement_text in mapping.items():
        target_ast = strip_parse(target_text)
        if isinstance(target_ast, Compare):
            # Key on the regenerated text of the comparison so lookups can
            # use ast2str of the nodes met while walking the tree.
            replacements[ast2str(target_ast)] = strip_parse(replacement_text)
            continue
        raise ValueError('Expression %s to substitute for is not a '
                         'comparison.' % target_text)

    return ast2str(_sub_subtrees_for_comps(tree, replacements))
def test__collect_num_denom(self):
    """
    Check the terms AST._collect_num_denom gathers as numerators and
    denominators, compared as sets of ast2str strings.
    """
    # (expression, expected numerators, expected denominators)
    table = [
        ('x-x', ['x - x'], []),
        ('1/2', ['1'], ['2']),
        ('1/2*3', ['1', '3'], ['2']),
        ('1/(2*3)', ['1'], ['2', '3']),
        ('1/(2/3)', ['1', '3'], ['2']),
        ('(1*2)*(3*4)', ['1', '2', '3', '4'], []),
        ('(1*2)*(3/4)', ['1', '2', '3'], ['4']),
        ('(1/2)*(3/4)', ['1', '3'], ['2', '4']),
        ('(1/2)/(3/4)', ['1', '4'], ['2', '3']),
    ]
    parsed = [(strip_parse(text), nums, denoms)
              for text, nums, denoms in table]

    for tree, want_nums, want_denoms in parsed:
        found_nums = []
        found_denoms = []
        AST._collect_num_denom(tree, found_nums, found_denoms)
        got_nums = {ast2str(term) for term in found_nums}
        got_denoms = {ast2str(term) for term in found_denoms}
        assert set(want_nums) == got_nums
        assert set(want_denoms) == got_denoms
def make_c_compatible(expr):
    """
    Convert a python math string into one compatible with C.

    The string is parsed, rewritten by _make_c_compatible_ast and turned
    back into a string. The rewrite:
      - replaces python-style x**n exponents with pow(x, n),
      - replaces integer constants with float values to avoid integer
        casting problems (e.g. '1' -> '1.0'),
      - replaces 'and', 'or', and 'not' with C's '&&', '||', and '!'.

    This may be fragile if the parsing library changes in newer python
    versions.
    """
    return ast2str(_make_c_compatible_ast(strip_parse(expr)))
def test__collect_pos_neg(self):
    """
    Check the terms AST._collect_pos_neg gathers as positive and negative,
    compared as sets of ast2str strings.
    """
    # (expression, expected positive terms, expected negative terms)
    table = [
        ('-y + z', ['z'], ['y']),
        ('1-2', ['1'], ['2']),
        ('1-2+3', ['1', '3'], ['2']),
        ('1-(2+3)', ['1'], ['2', '3']),
        ('1-(2-3)', ['1', '3'], ['2']),
        ('(1-2)-(3-4)', ['1', '4'], ['2', '3']),
        ('(1+2)+(3+4)', ['1', '2', '3', '4'], []),
        ('(1+2)+(3-4)', ['1', '2', '3'], ['4']),
        ('(1-2)+(3-4)', ['1', '3'], ['2', '4']),
        ('(1-2)-(3-4)', ['1', '4'], ['2', '3']),
    ]
    parsed = [(strip_parse(text), poss, negs)
              for text, poss, negs in table]

    for tree, want_pos, want_neg in parsed:
        found_pos = []
        found_neg = []
        AST._collect_pos_neg(tree, found_pos, found_neg)
        got_pos = {ast2str(term) for term in found_pos}
        got_neg = {ast2str(term) for term in found_neg}
        assert set(want_pos) == got_pos
        assert set(want_neg) == got_neg
def sub_for_func(expr, func_name, func_vars, func_expr):
    """
    Return a string with the function func_name substituted for its exploded
    form.

    expr: The expression string in which to substitute.
    func_name: The name of the function.
    func_vars: A sequence of the variables used by the function expression,
               or the string '*' (see below).
    func_expr: The expression for the function.

    For example:
        If f(x, y, z) = sqrt(z)*x*y-z
        func_name = 'f'
        func_vars = ['x', 'y', 'z']
        func_expr = 'sqrt(z)*x*y-z'

    As a special case, functions that take a variable number of arguments can
    use '*' for func_vars.
    For example:
        sub_for_func('or_func(or_func(A,D),B,C)', 'or_func', '*', 'x or y')
        yields '(A or D) or B or C'

    The substituted ast is passed through Simplify._simplify_ast before
    being converted back to a string.

    Raises ValueError if func_name or any entry of func_vars is not a simple
    name, or if func_vars is '*' and the parsed func_expr has no 'values'
    attribute to receive the arguments.
    """
    ast = strip_parse(expr)

    func_name_ast = strip_parse(func_name)
    if not isinstance(func_name_ast, Name):
        raise ValueError('Function name is not a simple name.')
    func_name = func_name_ast.id

    func_expr_ast = strip_parse(func_expr)

    # We can't strip_parse the '*', so we special case it here.
    if func_vars == '*':
        # The arguments of a variadic function are written into the
        # 'values' attribute of the function expression's top-level node,
        # so that attribute has to exist.
        if not hasattr(func_expr_ast, 'values'):
            raise ValueError("Top-level function in %s does not appear to "
                             "accept variable number of arguments. (It has no "
                             "'values' attribute.)" % func_expr)
        func_var_names = '*'
    else:
        func_vars_ast = [strip_parse(var) for var in func_vars]
        for var_ast in func_vars_ast:
            if not isinstance(var_ast, Name):
                raise ValueError('Function variable is not a simple name.')
        func_var_names = [var_ast.id for var_ast in func_vars_ast]

    ast = _sub_for_func_ast(ast, func_name, func_var_names, func_expr_ast)
    simple = Simplify._simplify_ast(ast)
    return ast2str(simple)
def test_ast2str(self):
    """
    Round-trip each expression through strip_parse and ast2str, then check
    that the regenerated text evaluates to the same value as the original.

    NOTE(review): the expressions reference x, y, z, f and g, which are not
    defined in this method; presumably they are module-level names visible
    to eval. The local names below are chosen so as not to shadow them.
    """
    expressions = [
        'x', 'x+y', 'x-y', 'x*y', 'x/y', 'x**y', '-x', 'x**-y',
        'x**(-y + z)', 'f(x)', 'g(x,y,z)', 'x**(y**z)', '(x**y)**z',
        'x**y**z', 'x - (x+y)', '(x+y) - z',
        'g(x-0+2, y**2 - 0**0, z*y + x/1)', 'x/x', 'x/y', '(x-x)/z',
        'x**2 - y/z', 'x+1-1+2-3-x', '0+1*1', 'x-x+y', '(-2)**2', '-2**2',
        'x < 0.5', 'x + y < 0.8', 'x > 0.5', 'x+y < z - x',
        '(f(x) < 0.5) == False', 'x == x', 'x + y == 2*x - x + y',
        'True and False', 'True and False or True', 'True or False',
        '(True and not False) or True', 'not (True and False)',
        'x == x and y == y', 'x - x == 0 or y - x != 0',
    ]
    for source_text in expressions:
        regenerated = ast2str(strip_parse(source_text))
        # Both evals stay directly in the method body so that name lookup
        # happens in the same scope for the original and regenerated text.
        expected = eval(source_text)
        actual = eval(regenerated)
        if expected == 0:
            # A relative error is meaningless around zero; require an exact
            # match instead.
            assert actual == 0
            continue
        relative_error = old_div(abs(expected - actual),
                                 (0.5 * (expected + actual)))
        assert relative_error < 1e-6
def _make_c_compatible_ast(ast):
    """
    Return a version of ast whose ast2str form is meant to be valid C.

    - x**n becomes a call pow(x, n).
    - Constants whose value is an int are converted in place to float
      (isinstance(value, int) is also true for bool values).
    - Subscript nodes (array[blah]) are returned untouched.
    - 'and' / 'or' become '&&' / '||', and 'not x' becomes '!(x)'.
    - Anything else is handed to AST.recurse_down_tree.
    """
    if isinstance(ast, BinOp) and isinstance(ast.op, Pow):
        pow_call = Call(func=Name(id='pow', ctx=Load()),
                        args=[ast.left, ast.right], keywords=[])
        return AST.recurse_down_tree(pow_call, _make_c_compatible_ast)

    if isinstance(ast, Constant) and isinstance(ast.value, int):
        ast.value = float(ast.value)
        return ast

    if isinstance(ast, Subscript):
        # These asts correspond to array[blah] and we shouldn't convert
        # these to floats, so we don't recurse down the tree in this case.
        return ast

    # We need to substitute the C logical operators. Unfortunately, they
    # aren't valid python syntax, so we have to cheat a little, using
    # Compare and Name nodes abusively: the operator strings are stored
    # where Compare normally keeps operator nodes. This abuse may not be
    # future-proof...
    if isinstance(ast, BoolOp) and isinstance(ast.op, And):
        operands = AST.recurse_down_tree(ast.values, _make_c_compatible_ast)
        trailing = list(operands[1:])
        return Compare(operands[0], ['&&'] * len(trailing), trailing)

    if isinstance(ast, BoolOp) and isinstance(ast.op, Or):
        operands = AST.recurse_down_tree(ast.values, _make_c_compatible_ast)
        trailing = list(operands[1:])
        return AST.Compare(operands[0], ['||'] * len(trailing), trailing)

    if isinstance(ast, UnaryOp) and isinstance(ast.op, Not):
        operand = AST.recurse_down_tree(ast.operand, _make_c_compatible_ast)
        # The negated operand is rendered to text and smuggled out as the
        # identifier of a Name node.
        return AST.Name(id='!(%s)' % ast2str(operand))

    return AST.recurse_down_tree(ast, _make_c_compatible_ast)
def _sub_subtrees_for_comps(ast, ast_mappings):
    """
    Substitute comparisons in ast according to ast_mappings.

    ast_mappings maps the ast2str text of a comparison to the ast that
    replaces it. If ast itself is such a comparison, the mapped ast is
    returned; otherwise the search continues through
    AST.recurse_down_tree.
    """
    if isinstance(ast, Compare):
        comparison_text = ast2str(ast)
        if comparison_text in ast_mappings:
            return ast_mappings[comparison_text]
    return AST.recurse_down_tree(ast, _sub_subtrees_for_comps,
                                 (ast_mappings,))