Ejemplo n.º 1
0
def _sub_for_func_ast(ast, func_name, func_vars, func_expr_ast):
    """
    Return an ast with the function func_name substituted out.

    ast: The tree to rewrite.
    func_name: Name of the function whose calls should be expanded.
    func_vars: Either a list of the function's parameter names, or the
        string '*' for a variadic function, in which case the call's
        arguments replace the ``values`` of func_expr_ast.
    func_expr_ast: Parsed body of the function.  It is deep-copied for each
        call site, so the caller's tree is never mutated.

    Calls to func_name nested inside a call's arguments are expanded as
    well.  A call whose argument count does not match func_vars is left as a
    call (only its children are rewritten).  Returns the (possibly new) root
    of the rewritten tree.

    NOTE(review): argument subtrees are inserted by reference, so a
    parameter used several times in the body shares one node object —
    confirm downstream passes never mutate such nodes in place.
    """
    if isinstance(ast, Call) and ast2str(ast.func) == func_name:
        if func_vars == '*':
            # Variadic form: the (expanded) call arguments become the
            # operands of the function expression.
            working_ast = copy.deepcopy(func_expr_ast)
            working_ast.values = [
                _sub_for_func_ast(arg_ast, func_name, func_vars, func_expr_ast)
                for arg_ast in ast.args]
            return working_ast
        if len(ast.args) == len(func_vars):
            # Expand each argument first, then plug the results into a fresh
            # copy of the function body in place of its parameters.
            working_ast = copy.deepcopy(func_expr_ast)
            mapping = {}
            for var_name, arg_ast in zip(func_vars, ast.args):
                mapping[var_name] = _sub_for_func_ast(
                    arg_ast, func_name, func_vars, func_expr_ast)
            # Keep the return value: when the body is a bare parameter name
            # (e.g. f(x) = x) the root itself is replaced, which an in-place
            # rewrite of working_ast cannot express.
            return _sub_subtrees_for_vars(working_ast, mapping)
    return AST.recurse_down_tree(ast, _sub_for_func_ast,
                                 (func_name, func_vars, func_expr_ast))
Ejemplo n.º 2
0
def _sub_subtrees_for_vars(ast, ast_mappings):
    """
    Replace every variable reference in ast whose name is a key of
    ast_mappings with the corresponding subtree.

    ast: The tree to rewrite.
    ast_mappings: Dict mapping variable names (as rendered by ast2str) to
        the replacement ASTs.

    Returns the rewritten tree.  When ast is itself a matching Name, the
    mapped subtree (the same object, not a copy) is returned instead;
    otherwise the children of ast are rewritten via AST.recurse_down_tree.
    """
    if isinstance(ast, Name):
        var_name = ast2str(ast)
        if var_name in ast_mappings:
            return ast_mappings[var_name]
    return AST.recurse_down_tree(ast, _sub_subtrees_for_vars, (ast_mappings,))
Ejemplo n.º 3
0
def sub_for_comps(expr, mapping):
    """
    Substitute whole comparison subexpressions inside an expression string.

    expr: The expression to rewrite.
    mapping: Dict whose keys are comparison expressions to look for (e.g.
        'x < 1') and whose values are the expressions to put in their place.

    Returns the rewritten expression as a string.  If mapping is empty, expr
    is returned unchanged without being parsed.

    Raises ValueError if a key of mapping does not parse to a comparison.
    """
    if len(mapping) == 0:
        return expr

    tree = strip_parse(expr)

    # Key each comparison by its ast2str rendering; _sub_subtrees_for_comps
    # renders candidate Compare nodes the same way before looking them up.
    comp_mapping = {}
    for comp_expr, replacement in list(mapping.items()):
        comp_ast = strip_parse(comp_expr)
        if not isinstance(comp_ast, Compare):
            raise ValueError('Expression %s to substitute for is not a '
                             'comparison.' % comp_expr)
        comp_mapping[ast2str(comp_ast)] = strip_parse(replacement)

    rewritten = _sub_subtrees_for_comps(tree, comp_mapping)
    return ast2str(rewritten)
Ejemplo n.º 4
0
 def test__collect_num_denom(self):
     """Check how products/quotients are split into numerator and denominator terms."""
     # Each entry: (expression, expected numerator terms, expected
     # denominator terms).  Terms are compared as sets, so order is ignored.
     expectations = [
         ('x-x', ['x - x'], []),
         ('1/2', ['1'], ['2']),
         ('1/2*3', ['1', '3'], ['2']),
         ('1/(2*3)', ['1'], ['2', '3']),
         ('1/(2/3)', ['1', '3'], ['2']),
         ('(1*2)*(3*4)', ['1', '2', '3', '4'], []),
         ('(1*2)*(3/4)', ['1', '2', '3'], ['4']),
         ('(1/2)*(3/4)', ['1', '3'], ['2', '4']),
         ('(1/2)/(3/4)', ['1', '4'], ['2', '3']),
     ]
     for text, want_nums, want_denoms in expectations:
         got_nums = []
         got_denoms = []
         AST._collect_num_denom(strip_parse(text), got_nums, got_denoms)
         assert set(want_nums) == {ast2str(term) for term in got_nums}
         assert set(want_denoms) == {ast2str(term) for term in got_denoms}
Ejemplo n.º 5
0
def make_c_compatible(expr):
    """
    Rewrite a Python math expression string so it can be emitted as C.

    The conversion (done by _make_c_compatible_ast) does the following:
      * Python-style ``x**n`` becomes ``pow(x, n)``.
      * Integer constants become floats ('1' -> '1.0') to avoid integer
        casting problems in C.
      * 'and', 'or' and 'not' become '&&', '||' and '!'.  This depends on
        how the expression is re-rendered and may be fragile if the parsing
        library changes in newer Python versions.

    Returns the converted expression as a string.
    """
    converted = _make_c_compatible_ast(strip_parse(expr))
    return ast2str(converted)
Ejemplo n.º 6
0
 def test__collect_pos_neg(self):
     """Check how sums/differences are split into positive and negative terms."""
     # Each case: (parsed expression, (expected positive terms, expected
     # negative terms)).  Terms are compared as sets, so order is ignored.
     # '(1-2)-(3-4)' used to be listed twice; it is kept once, at the end,
     # mirroring the layout of test__collect_num_denom.
     cases = [(strip_parse('-y + z'), (['z'], ['y'])),
              (strip_parse('1-2'), (['1'], ['2'])),
              (strip_parse('1-2+3'), (['1', '3'], ['2'])),
              (strip_parse('1-(2+3)'), (['1'], ['2', '3'])),
              (strip_parse('1-(2-3)'), (['1', '3'], ['2'])),
              (strip_parse('(1+2)+(3+4)'),
               (['1', '2', '3', '4'], [])),
              (strip_parse('(1+2)+(3-4)'),
               (['1', '2', '3'], ['4'])),
              (strip_parse('(1-2)+(3-4)'),
               (['1', '3'], ['2', '4'])),
              (strip_parse('(1-2)-(3-4)'),
               (['1', '4'], ['2', '3'])),
              ]
     for ast, (poss, negs) in cases:
         p, n = [], []
         AST._collect_pos_neg(ast, p, n)
         p = [ast2str(term) for term in p]
         n = [ast2str(term) for term in n]
         assert set(poss) == set(p)
         assert set(negs) == set(n)
Ejemplo n.º 7
0
def sub_for_func(expr, func_name, func_vars, func_expr):
    """
    Return a string with the function func_name substituted for its exploded
    form.

    func_name: The name of the function.
    func_vars: A sequence of variable names used by the function expression,
        or the string '*' (see below).
    func_expr: The expression for the function.
    For example:
        If f(x, y, z) = sqrt(z)*x*y-z
        func_name = 'f'
        func_vars = ['x', 'y', 'z']
        func_expr = 'sqrt(z)*x*y-z'

    As a special case, functions that take a variable number of arguments can
    use '*' for func_vars.  func_expr must then parse to a node with a
    'values' attribute (such as 'x or y'); each call's arguments replace
    those values.
    For example:
        sub_for_func('or_func(or_func(A,D),B,C)', 'or_func', '*', 'x or y')
        yields '(A or D) or B or C'

    The substituted tree is passed through Simplify._simplify_ast before
    being converted back to a string.

    Raises ValueError if func_name or any entry of func_vars is not a simple
    name, or if func_vars is '*' and func_expr has no 'values' attribute.
    """
    ast = strip_parse(expr)
    func_name_ast = strip_parse(func_name)
    if not isinstance(func_name_ast, Name):
        raise ValueError('Function name is not a simple name.')
    func_name = func_name_ast.id

    func_expr_ast = strip_parse(func_expr)
    # We can't strip_parse the '*', so we special case it here.
    if func_vars == '*':
        if not hasattr(func_expr_ast, 'values'):
            raise ValueError("Top-level function in %s does not appear to "
                             "accept variable number of arguments. (It has no "
                             "'values' attribute.)" % func_expr)

        func_var_names = '*'
    else:
        func_vars_ast = [strip_parse(var) for var in func_vars]
        for var_ast in func_vars_ast:
            if not isinstance(var_ast, Name):
                raise ValueError('Function variable is not a simple name.')
        func_var_names = [var_ast.id for var_ast in func_vars_ast]

    ast = _sub_for_func_ast(ast, func_name, func_var_names, func_expr_ast)
    simple = Simplify._simplify_ast(ast)
    return ast2str(simple)
Ejemplo n.º 8
0
 def test_ast2str(self):
     """Round-trip expressions through strip_parse/ast2str and compare values."""
     # Both the original string and its regenerated form are eval'd; the
     # names x, y, z, f and g are presumably defined at module level in this
     # test file — TODO confirm.
     cases = ['x', 'x+y', 'x-y', 'x*y', 'x/y', 'x**y', '-x', 'x**-y',
              'x**(-y + z)', 'f(x)', 'g(x,y,z)', 'x**(y**z)', 
              '(x**y)**z', 'x**y**z', 'x - (x+y)', '(x+y) - z',
              'g(x-0+2, y**2 - 0**0, z*y + x/1)', 'x/x', 'x/y',
              '(x-x)/z', 'x**2 - y/z', 'x+1-1+2-3-x', '0+1*1', 'x-x+y',
              '(-2)**2', '-2**2', 'x < 0.5', 'x + y < 0.8', 'x > 0.5',
              'x+y < z - x', '(f(x) < 0.5) == False', 'x == x', 
              'x + y == 2*x - x + y', 'True and False', 
              'True and False or True', 'True or False', 
              '(True and not False) or True', 'not (True and False)',
              'x == x and y == y', 'x - x == 0 or y - x != 0']
     for expr in cases:
         run = ast2str(strip_parse(expr))
         orig = eval(expr)
         out = eval(run)
         if orig != 0:
             # Relative tolerance against the *magnitude* of the mean, so that
             # negative results are actually checked (a signed denominator
             # makes the ratio negative and always < 1e-6) and opposite-signed
             # values fail the assertion instead of dividing by zero.
             assert abs(orig - out) < 1e-6 * abs(0.5 * (orig + out))
         else:
             assert out == 0
Ejemplo n.º 9
0
def _make_c_compatible_ast(ast):
    if isinstance(ast, BinOp) and isinstance(ast.op, Pow):
        ast = Call(func=Name(id='pow', ctx=Load()), args=[ast.left, ast.right], keywords=[])
        ast = AST.recurse_down_tree(ast, _make_c_compatible_ast)
    elif isinstance(ast, Constant) and isinstance(ast.value, int):
        ast.value = float(ast.value)
    elif isinstance(ast, Subscript):
        # These asts correspond to array[blah] and we shouldn't convert these
        # to floats, so we don't recurse down the tree in this case.
        pass
    # We need to subsitute the C logical operators. Unfortunately, they aren't
    # valid python syntax, so we have to cheat a little, using Compare and Name
    # nodes abusively. This abuse may not be future-proof... sigh...
    elif isinstance(ast, BoolOp) and isinstance(ast.op, And):
        nodes = AST.recurse_down_tree(ast.values, _make_c_compatible_ast)
        ops = [('&&', node) for node in nodes[1:]]
        ops_1 = []
        c = []
        for k, v in ops:
            ops_1.append(k)
            c.append(v)
            
        ast = Compare(nodes[0], ops_1, c)
    elif isinstance(ast, BoolOp) and isinstance(ast.op, Or):
        nodes = AST.recurse_down_tree(ast.values, _make_c_compatible_ast)
        ops = [('||', node) for node in nodes[1:]]
        ops_1 = []
        c = []
        for k, v in ops:
            ops_1.append(k)
            c.append(v)
            
        ast = AST.Compare(nodes[0], ops_1, c)
    elif isinstance(ast, UnaryOp) and isinstance(ast.op, Not):
        expr = AST.recurse_down_tree(ast.operand, _make_c_compatible_ast)
        ast = AST.Name(id='!(%s)' % ast2str(expr))
    else:
        ast = AST.recurse_down_tree(ast, _make_c_compatible_ast)
    return ast
Ejemplo n.º 10
0
def _sub_subtrees_for_comps(ast, ast_mappings):
    """
    Replace every comparison in ast whose ast2str rendering is a key of
    ast_mappings with the corresponding subtree.

    ast: The tree to rewrite.
    ast_mappings: Dict mapping rendered comparison strings to replacement
        ASTs.

    Returns the rewritten tree.  When ast is itself a matching Compare, the
    mapped subtree is returned and its children are not visited; otherwise
    the children of ast are rewritten via AST.recurse_down_tree.
    """
    if isinstance(ast, Compare):
        rendered = ast2str(ast)
        if rendered in ast_mappings:
            return ast_mappings[rendered]
    return AST.recurse_down_tree(ast, _sub_subtrees_for_comps, (ast_mappings,))