Example #1
0
def simplify_expr(expr):
    """
    Simplify a Python expression string.

    The string is parsed into an AST, run through the AST simplifier, and
    rendered back into a string, which is returned.
    """
    return AST.ast2str(_simplify_ast(AST.strip_parse(expr)))
Example #2
0
def extract_comps(expr):
    """
    Return the set of comparison sub-expressions of expr, as strings.
    """
    comparisons = []
    _extract_comps_ast(AST.strip_parse(expr), comparisons)
    return {AST.ast2str(node) for node in comparisons}
Example #3
0
def _extract_funcs_ast(ast, funcs_found):
    """
    Walk ast, appending a ('name', number_of_args) tuple to funcs_found for
    every function call encountered.

    Returns whatever AST.recurse_down_tree returns for ast.

    NOTE(review): call arguments are visited explicitly here and are
    presumably visited again by AST.recurse_down_tree, so entries may repeat;
    extract_funcs converts the list to a set, so that caller is unaffected.
    """
    if isinstance(ast, Call):
        callee = AST.ast2str(ast.func)
        funcs_found.append((callee, len(ast.args)))
        # Descend into the arguments so nested calls are recorded as well.
        for argument in ast.args:
            _extract_funcs_ast(argument, funcs_found)
    return AST.recurse_down_tree(ast, _extract_funcs_ast, (funcs_found,))
Example #4
0
def extract_vars(expr):
    """
    Return the set of variable names (as strings) used in expr.

    Results are memoized in the global extract_vars_cache dict, keyed by the
    expression itself.
    """
    try:
        cached = extract_vars_cache[expr]
    except KeyError:
        pass
    else:
        return cached

    found = []
    _extract_vars_ast(AST.strip_parse(expr), found)
    names = {AST.ast2str(node) for node in found}
    extract_vars_cache[expr] = names
    return names
Example #5
0
def _sub_for_func_ast(ast, func_name, func_vars, func_expr_ast):
    """
    Return an ast with the function func_name substituted out.

    ast: Tree to search for calls to func_name.
    func_name: Name of the function whose calls are replaced.
    func_vars: Sequence of the function's argument names, or the string '*'
        for a function that takes any number of arguments.
    func_expr_ast: AST of the function body. It is deep-copied for every
        replaced call, so func_expr_ast itself is never modified.

    Arguments of a replaced call are processed first, so nested calls to
    func_name are substituted as well. A call whose argument count does not
    match func_vars is handed to AST.recurse_down_tree like any other node.
    """
    if isinstance(ast, Call) and ast2str(ast.func) == func_name:
        if func_vars == '*':
            # Variadic function: the call's (substituted) arguments become
            #  the operand list of a copy of the body.
            # NOTE(review): assumes func_expr_ast is a node with a `values`
            #  list (e.g. a BoolOp); confirm against callers.
            working_ast = copy.deepcopy(func_expr_ast)
            working_ast.values = [
                _sub_for_func_ast(arg_ast, func_name, func_vars, func_expr_ast)
                for arg_ast in ast.args
            ]
            return working_ast
        if len(ast.args) == len(func_vars):
            # Take a copy of the function body and substitute each formal
            #  argument name by the (already substituted) actual argument.
            working_ast = copy.deepcopy(func_expr_ast)
            mapping = {}
            for var_name, arg_ast in zip(func_vars, ast.args):
                mapping[var_name] = _sub_for_func_ast(arg_ast, func_name,
                                                      func_vars, func_expr_ast)
            # Use the return value: when the body is just one of the argument
            #  names (e.g. f(x) = x), the root node itself is the one that
            #  gets replaced, which an in-place walk cannot express.
            return _sub_subtrees_for_vars(working_ast, mapping)
    return AST.recurse_down_tree(ast, _sub_for_func_ast,
                                 (func_name, func_vars, func_expr_ast))
Example #6
0
def extract_funcs(expr):
    """
    Return the set of functions called in expr.

    Each element is a ('function_name', number_of_arguments) tuple.
    """
    calls = []
    tree = AST.strip_parse(expr)
    _extract_funcs_ast(tree, calls)
    return set(calls)
Example #7
0
def diff_expr(expr, wrt):
    """
    Return the derivative of the expression with respect to a given variable.

    expr: Python expression string.
    wrt: Name of the variable to differentiate with respect to.

    The derivative is simplified and returned as a string. Results are
    memoized in the global __deriv_saved dict, keyed on both expr and wrt.
    """
    # Pass arguments to the logger instead of pre-formatting, so the message
    #  is only built when debug logging is actually enabled.
    logger.debug('Taking derivative of %s wrt %s', expr, wrt)
    key = '%s__derivWRT__%s' % (expr, wrt)
    if key in __deriv_saved:
        deriv = __deriv_saved[key]
        logger.debug('Found saved result %s.', deriv)
        return deriv

    ast = AST.strip_parse(expr)
    deriv = _diff_ast(ast, wrt)
    deriv = Simplify._simplify_ast(deriv)
    deriv = AST.ast2str(deriv)
    __deriv_saved[key] = deriv
    logger.debug('Computed result %s.', deriv)
    return deriv
Example #8
0
def _extract_vars_ast(ast, vars_found):
    """
    Walk ast, appending every Name node except True/False to vars_found.

    Returns whatever AST.recurse_down_tree returns for ast.
    """
    is_variable = isinstance(ast, Name) and ast.id not in ('True', 'False')
    if is_variable:
        vars_found.append(ast)
    return AST.recurse_down_tree(ast, _extract_vars_ast, (vars_found,))
Example #9
0
def expr2TeX(expr, name_dict=None):
    """
    Return a TeX version of a python math expression.

    expr: Python expression string.
    name_dict: Optional dictionary mapping variable names used in the
        expression to preferred TeX expressions. Defaults to no remapping.
    """
    # Avoid a shared mutable default argument.
    if name_dict is None:
        name_dict = {}
    ast = AST.strip_parse(expr)
    return _ast2TeX(ast, name_dict=name_dict)
Example #10
0
def _sub_subtrees_for_vars(ast, ast_mappings):
    """
    Replace variables in ast by sub-trees.

    ast_mappings maps a variable name (string) to the AST that should replace
    every occurrence of that variable. Returns the resulting tree; when ast
    itself is a mapped variable, the replacement is returned instead, so
    callers must use the return value.
    """
    if isinstance(ast, Name):
        var_name = ast2str(ast)
        if var_name in ast_mappings:
            return ast_mappings[var_name]
    return AST.recurse_down_tree(ast, _sub_subtrees_for_vars,
                                 (ast_mappings,))
Example #11
0
def dict2TeX(d, name_dict, lhs_form='%s', split_terms=False, simpleTeX=False):
    """
    Return TeX for a table with one 'lhs = rhs' row per item of d.

    d: Mapping from left-hand-side expressions to right-hand-side expression
        strings.
    name_dict: Mapping of variable names to preferred TeX, passed through to
        the TeX converters.
    lhs_form: Format string applied to each TeX'd lhs. If formatting raises
        TypeError (e.g. the form has no '%s' placeholder), the form is used
        verbatim instead.
    split_terms: If True, each additive term of the rhs gets its own row,
        positive terms first, then negative ones.
    simpleTeX: If False, rows are spaced apart, fractions are switched to a
        tabfrac helper macro, and everything is wrapped in a longtable
        environment. If True, only the bare rows are returned.

    Rows are joined with os.linesep.
    """
    def format_lhs(lhs):
        # Same fallback for both layouts: a lhs_form that cannot take the
        #  TeX'd lhs is used as-is.
        try:
            return lhs_form % expr2TeX(lhs, name_dict=name_dict)
        except TypeError:
            return lhs_form

    rows = []
    for lhs, rhs in d.items():
        if split_terms:
            pos, neg = [], []
            AST._collect_pos_neg(AST.strip_parse(rhs), pos, neg)
            lhs_tex = format_lhs(lhs)
            pos_tex = [_ast2TeX(term, name_dict=name_dict) for term in pos]
            neg_tex = [_ast2TeX(term, name_dict=name_dict) for term in neg]
            if pos_tex:
                first = pos_tex.pop(0)
            elif neg_tex:
                # No positive terms (e.g. rhs '-x - y'): lead with the first
                #  negative term instead of failing on an empty list.
                first = r'- \, %s' % neg_tex.pop(0)
            else:
                first = '0'
            rows.append(r'$ %s $ &=& $ %s $\\' % (lhs_tex, first))
            rows.extend(r' & & $ + \, %s $\\' % tex for tex in pos_tex)
            rows.extend(r' & & $ - \, %s $\\' % tex for tex in neg_tex)
        else:
            lhs_tex = format_lhs(lhs)
            rhs_tex = expr2TeX(rhs, name_dict=name_dict)
            rows.append(r'$ %s $ & = & $ %s $\\' % (lhs_tex, rhs_tex))

        if not simpleTeX:
            # Force a space between TeX'd entries
            rows[-1] = '%s[5mm]' % rows[-1]

    body = os.linesep.join(rows)
    if simpleTeX:
        return body

    body = body.replace(r'\frac{', r'\tabfrac{')
    # This makes the fractions look much nicer in the tabular output. See
    #  http://www.texnik.de/table/table.phtml#fractions
    wrapped = [
        r'\providecommand{\tabfrac}[2]{%',
        r'   \setlength{\fboxrule}{0pt}%',
        r'   \fbox{$\frac{#1}{#2}$}}',
        r'\begin{longtable}{lll}',
        body,
        r'\end{longtable}',
    ]
    return os.linesep.join(wrapped)
Example #12
0
 def test__collect_num_denom(self):
     """_collect_num_denom splits products/quotients into factor lists."""
     # (expression, expected numerator factors, expected denominator
     #  factors); factor order is not checked.
     expected = [
         ('x-x', ['x - x'], []),
         ('1/2', ['1'], ['2']),
         ('1/2*3', ['1', '3'], ['2']),
         ('1/(2*3)', ['1'], ['2', '3']),
         ('1/(2/3)', ['1', '3'], ['2']),
         ('(1*2)*(3*4)', ['1', '2', '3', '4'], []),
         ('(1*2)*(3/4)', ['1', '2', '3'], ['4']),
         ('(1/2)*(3/4)', ['1', '3'], ['2', '4']),
         ('(1/2)/(3/4)', ['1', '4'], ['2', '3']),
     ]
     for expr, want_num, want_denom in expected:
         num_terms, denom_terms = [], []
         AST._collect_num_denom(strip_parse(expr), num_terms, denom_terms)
         got_num = {ast2str(term) for term in num_terms}
         got_denom = {ast2str(term) for term in denom_terms}
         assert set(want_num) == got_num
         assert set(want_denom) == got_denom
Example #13
0
 def test__collect_num_denom(self):
     # Each case: (parsed expression, (expected numerator factors, expected
     #  denominator factors)); factor order is not checked. Trees are built
     #  with strip_parse so the test depends only on the parser, not on a
     #  particular node-constructor API, and results are compared with the
     #  builtin set (the `sets` module was removed in Python 3).
     cases = [(strip_parse('1'), (['1'], [])),
              (strip_parse('1/2'), (['1'], ['2'])),
              (strip_parse('1/2*3'), (['1', '3'], ['2'])),
              (strip_parse('1/(2*3)'), (['1'], ['2', '3'])),
              (strip_parse('1/(2/3)'), (['1', '3'], ['2'])),
              (strip_parse('(1*2)*(3*4)'),
               (['1', '2', '3', '4'], [])),
              (strip_parse('(1*2)*(3/4)'),
               (['1', '2', '3'], ['4'])),
              (strip_parse('(1/2)*(3/4)'),
               (['1', '3'], ['2', '4'])),
              (strip_parse('(1/2)/(3/4)'),
               (['1', '4'], ['2', '3'])),
              ]
     for ast, (nums, denoms) in cases:
         n, d = [], []
         AST._collect_num_denom(ast, n, d)
         n = [ast2str(term) for term in n]
         d = [ast2str(term) for term in d]
         assert set(nums) == set(n), (ast2str(ast), n)
         assert set(denoms) == set(d), (ast2str(ast), d)
Example #14
0
 def test__collect_pos_neg(self):
     # Each case: (parsed expression, (expected positive terms, expected
     #  negative terms)); term order is not checked. Each expression is
     #  listed once (the '(1-2)-(3-4)' case used to appear twice).
     cases = [(strip_parse('-y + z'), (['z'], ['y'])),
              (strip_parse('1-2'), (['1'], ['2'])),
              (strip_parse('1-2+3'), (['1', '3'], ['2'])),
              (strip_parse('1-(2+3)'), (['1'], ['2', '3'])),
              (strip_parse('1-(2-3)'), (['1', '3'], ['2'])),
              (strip_parse('(1-2)-(3-4)'), (['1', '4'], ['2', '3'])),
              (strip_parse('(1+2)+(3+4)'),
               (['1', '2', '3', '4'], [])),
              (strip_parse('(1+2)+(3-4)'),
               (['1', '2', '3'], ['4'])),
              (strip_parse('(1-2)+(3-4)'),
               (['1', '3'], ['2', '4'])),
              ]
     for ast, (poss, negs) in cases:
         p, n = [], []
         AST._collect_pos_neg(ast, p, n)
         p = [ast2str(term) for term in p]
         n = [ast2str(term) for term in n]
         assert set(poss) == set(p), (ast2str(ast), p)
         assert set(negs) == set(n), (ast2str(ast), n)
Example #15
0
 def test__collect_pos_neg(self):
     # Each case: (parsed expression, (expected positive terms, expected
     #  negative terms)); term order is not checked. Trees are built with
     #  strip_parse so the test depends only on the parser, not on a
     #  particular node-constructor API, and results are compared with the
     #  builtin set (the `sets` module was removed in Python 3).
     cases = [(strip_parse('1'), (['1'], [])),
              (strip_parse('1-2'), (['1'], ['2'])),
              (strip_parse('1-2+3'), (['1', '3'], ['2'])),
              (strip_parse('1-(2+3)'), (['1'], ['2', '3'])),
              (strip_parse('1-(2-3)'), (['1', '3'], ['2'])),
              (strip_parse('(1-2)-(3-4)'), (['1', '4'], ['2', '3'])),
              (strip_parse('(1+2)+(3+4)'),
               (['1', '2', '3', '4'], [])),
              (strip_parse('(1+2)+(3-4)'),
               (['1', '2', '3'], ['4'])),
              (strip_parse('(1-2)+(3-4)'),
               (['1', '3'], ['2', '4'])),
              ]
     for ast, (poss, negs) in cases:
         p, n = [], []
         AST._collect_pos_neg(ast, p, n)
         p = [ast2str(term) for term in p]
         n = [ast2str(term) for term in n]
         assert set(poss) == set(p), (ast2str(ast), p)
         assert set(negs) == set(n), (ast2str(ast), n)
Example #16
0
def _product_deriv(terms, wrt):
    """
    Return an AST for the derivative, with respect to wrt, of the product of
    the ASTs in terms.

    Applies the product rule: for each factor, one summand is built from the
    other factors times that factor's derivative (placed last). The summands
    are chained with Add nodes, each new summand becoming the left operand.
    """
    if len(terms) == 1:
        return _diff_ast(terms[0], wrt)

    summands = []
    for idx, factor in enumerate(terms):
        factor_d = _diff_ast(factor, wrt)
        others = terms[:idx] + terms[idx + 1:]
        summands.append(AST._make_product(others + [factor_d]))

    total = summands[0]
    for summand in summands[1:]:
        total = BinOp(left=summand, op=Add(), right=total)
    return total
Example #17
0
def _make_c_compatible_ast(ast):
    """
    Return a version of ast whose string form is closer to valid C.

    Conversions performed:
      * x**y becomes the call pow(x, y);
      * Constant nodes holding an int (bools included, since bool subclasses
        int) are converted in place to float;
      * and/or chains become Compare nodes whose operators are the strings
        '&&' / '||';
      * `not x` becomes a Name whose id is the text '!(x)';
      * Subscript nodes (array[index]) are left untouched, including their
        index, so integer indices are not turned into floats.
    Everything else is handled by recursing into child nodes.
    """
    if isinstance(ast, BinOp) and isinstance(ast.op, Pow):
        ast = Call(func=Name(id='pow', ctx=Load()), args=[ast.left, ast.right],
                   keywords=[])
        ast = AST.recurse_down_tree(ast, _make_c_compatible_ast)
    elif isinstance(ast, Constant) and isinstance(ast.value, int):
        ast.value = float(ast.value)
    elif isinstance(ast, Subscript):
        # These asts correspond to array[blah] and we shouldn't convert these
        # to floats, so we don't recurse down the tree in this case.
        pass
    # We need to substitute the C logical operators. Unfortunately, they
    # aren't valid python syntax, so we have to cheat a little, using Compare
    # and Name nodes abusively. This abuse may not be future-proof... sigh...
    elif isinstance(ast, BoolOp) and isinstance(ast.op, (And, Or)):
        c_op = '&&' if isinstance(ast.op, And) else '||'
        nodes = AST.recurse_down_tree(ast.values, _make_c_compatible_ast)
        # a && b && c -> Compare(a, ['&&', '&&'], [b, c])
        ast = Compare(nodes[0], [c_op] * (len(nodes) - 1), list(nodes[1:]))
    elif isinstance(ast, UnaryOp) and isinstance(ast.op, Not):
        expr = AST.recurse_down_tree(ast.operand, _make_c_compatible_ast)
        ast = Name(id='!(%s)' % ast2str(expr))
    else:
        ast = AST.recurse_down_tree(ast, _make_c_compatible_ast)
    return ast
Example #18
0
def _is_const_value(node, singleton, number):
    """
    Return True if node stands for the numeric constant number.

    `node == singleton` catches the shared module constant (_ZERO / _ONE)
    returned by other simplifications. Stdlib AST nodes do not define value
    equality, so a Constant produced by the parser or by constant folding
    never compares equal to the singleton; such nodes are matched by value.
    """
    return node == singleton or (isinstance(node, Constant)
                                 and node.value == number)


def _simplify_ast(ast):
    """
    Return a simplified ast.

    Current simplifications:
        Special cases for zeros and ones, and combining of constants, in
            addition, subtraction, multiplication, division.
        Constants are gathered from the whole flattened sum or product and
            combined into one, which is placed last: 1+x+1 -> x+2.
        Repeated terms are counted: x - x = 0, x + x = 2*x, x*x = x**2.
        --x = x
        x**0 = 1, 0**x = 0, 1**x = 1, x**1 = x, and a constant raised to a
            constant power is evaluated.

    Lists and tuples are simplified element-wise. Other node types listed in
    AST._node_attrs have their children simplified in place and are
    returned; anything else is returned unchanged.
    """
    if isinstance(ast, Name) or isinstance(ast, Constant):
        return ast
    elif isinstance(ast, BinOp) and (isinstance(ast.op, Add)
                                     or isinstance(ast.op, Sub)):

        # We collect positive and negative terms and simplify each of them
        pos, neg = [], []
        AST._collect_pos_neg(ast, pos, neg)

        pos = [_simplify_ast(term) for term in pos]
        neg = [_simplify_ast(term) for term in neg]
        # We collect and sum the constant values
        values = ([term.value for term in pos if isinstance(term, Constant)] +
                  [-term.value for term in neg if isinstance(term, Constant)])
        value = sum(values)
        # Remove the constants from our pos and neg lists
        pos = [term for term in pos if not isinstance(term, Constant)]
        neg = [term for term in neg if not isinstance(term, Constant)]
        # A negated term (-x) moves to the opposite list with its minus sign
        #  stripped. Any other term, including other unary operators such as
        #  `not x` or `~x`, stays where it is (previously those were dropped).
        new_pos, new_neg = [], []
        for term in pos:
            if isinstance(term, UnaryOp) and isinstance(term.op, USub):
                new_neg.append(term.operand)
            else:
                new_pos.append(term)
        for term in neg:
            if isinstance(term, UnaryOp) and isinstance(term.op, USub):
                new_pos.append(term.operand)
            else:
                new_neg.append(term)
        pos, neg = new_pos, new_neg
        # Append the constant value sum to pos or neg
        if value > 0:
            pos.append(Constant(value=value))
        elif value < 0:
            neg.append(Constant(value=abs(value)))
        # Count the number of occurrences of each term.
        term_counts = [
            (term,
             get_count_from_ast(pos, term) - get_count_from_ast(neg, term))
            for term in pos + neg
        ]
        # Tricky: We use the str(term) as the key for the dictionary to ensure
        #         that each entry represents a unique term. Terms whose net
        #         count is 0 cancel out and are skipped below.
        term_counts = {AST.ast2str(term): (term, count)
                       for term, count in term_counts}
        # We find the first term with non-zero count.
        ii = 0
        for ii, term in enumerate(pos + neg):
            ast_out, count = term_counts[AST.ast2str(term)]
            if count != 0:
                break
        else:
            # We get here if we don't break out of the loop, implying that
            #  all our terms had count of 0
            return _ZERO
        term_counts[AST.ast2str(term)] = (ast_out, 0)
        if abs(count) != 1:
            ast_out = BinOp(left=Constant(value=abs(count)),
                            op=Mult(),
                            right=ast_out)
        if count < 0:
            ast_out = UnaryOp(op=USub(), operand=ast_out)

        # And add in all the rest
        for term in (pos + neg)[ii:]:
            term, count = term_counts[AST.ast2str(term)]
            term_counts[AST.ast2str(term)] = (term, 0)
            if abs(count) != 1:
                term = BinOp(left=Constant(value=abs(count)),
                             op=Mult(),
                             right=term)
            if count > 0:
                ast_out = BinOp(left=ast_out, op=Add(), right=term)
            elif count < 0:
                ast_out = BinOp(left=ast_out, op=Sub(), right=term)
        return ast_out

    elif isinstance(ast, BinOp) and (isinstance(ast.op, Mult)
                                     or isinstance(ast.op, Div)):
        # We collect numerator and denominator terms and simplify each of them
        num, denom = [], []
        AST._collect_num_denom(ast, num, denom)
        num = [_simplify_ast(term) for term in num]
        denom = [_simplify_ast(term) for term in denom]
        # We collect the constant values, inverting those in the denominator.
        #  NOTE: a literal 0 in the denominator makes 1./term.value raise
        #  ZeroDivisionError here.
        values = ([term.value for term in num if isinstance(term, Constant)] +
                  [1./term.value for term in denom
                   if isinstance(term, Constant)])
        # This takes the product of all our values
        value = functools.reduce(operator.mul, values + [1])
        # If our value is 0, the expression is 0
        if not value:
            return _ZERO
        # Remove the constants from our num and denom lists
        num = [term for term in num if not isinstance(term, Constant)]
        denom = [term for term in denom if not isinstance(term, Constant)]
        # Here we count all the negative (UnarySub) elements of our expression.
        # We also remove the UnarySubs from their arguments. We'll correct
        #  for it at the end.
        num_neg = 0
        for list_of_terms in [num, denom]:
            for ii, term in enumerate(list_of_terms):
                if isinstance(term, UnaryOp) and isinstance(term.op, USub):
                    list_of_terms[ii] = term.operand
                    num_neg += 1

        # Fold the constant product's magnitude into the numerator and its
        #  sign into the count of negations.
        if abs(value) != 1:
            num.append(Constant(value=abs(value)))
        if value < 0:
            num_neg += 1

        make_neg = num_neg % 2
        # Count the number of occurrences of each term.
        term_counts = [
            (term,
             get_count_from_ast(num, term) - get_count_from_ast(denom, term))
            for term in num + denom
        ]
        # Tricky: We use the str(term) as the key for the dictionary to ensure
        #         that each entry represents a unique term. Terms whose net
        #         count is 0 cancel out and are skipped below.
        term_counts = {AST.ast2str(term): (term, count)
                       for term, count in term_counts}

        nums, denoms = [], []
        # We walk through terms in num+denom in order, so we rearrange as
        #  little as possible.
        for term in num + denom:
            term, count = term_counts[AST.ast2str(term)]
            # Once a term has been done, we set its term_counts to 0, so it
            #  doesn't get done again.
            term_counts[AST.ast2str(term)] = (term, 0)
            if abs(count) > 1:
                term = BinOp(left=term,
                             op=Pow(),
                             right=Constant(value=abs(count)))
            if count > 0:
                nums.append(term)
            elif count < 0:
                denoms.append(term)

        # We return the product of the numerator terms over the product of the
        #  denominator terms
        out = AST._make_product(nums)
        if denoms:
            denom = AST._make_product(denoms)
            out = BinOp(left=out, op=Div(), right=denom)

        if make_neg:
            out = UnaryOp(op=USub(), operand=out)
        return out
    elif isinstance(ast, BinOp) and isinstance(ast.op, Pow):
        power = _simplify_ast(ast.right)
        base = _simplify_ast(ast.left)

        # Compare by value as well as against the singletons, so parsed
        #  constants such as the 0 in `x**0` are recognized too.
        if _is_const_value(power, _ZERO, 0):
            # Anything, including 0, to the 0th power is 1, so this
            #  test should come first
            return _ONE
        if (_is_const_value(base, _ZERO, 0) or _is_const_value(base, _ONE, 1)
                or _is_const_value(power, _ONE, 1)):
            return base
        elif isinstance(base, Constant) and isinstance(power, Constant):
            return Constant(value=base.value**power.value)
        # Getting here implies that no simplifications are possible, so just
        #  return with simplified arguments
        return BinOp(left=base, op=Pow(), right=power)

    elif isinstance(ast, UnaryOp) and isinstance(ast.op, USub):
        simple_expr = _simplify_ast(ast.operand)
        if isinstance(simple_expr, UnaryOp) and isinstance(
                simple_expr.op, USub):
            # Case --x
            return _simplify_ast(simple_expr.operand)
        elif isinstance(simple_expr, Constant):
            if simple_expr.value == 0:
                return Constant(value=0)
            else:
                return Constant(value=-simple_expr.value)
        else:
            return UnaryOp(op=USub(), operand=simple_expr)
    elif isinstance(ast, UnaryOp) and isinstance(ast.op, UAdd):
        simple_expr = _simplify_ast(ast.operand)
        return simple_expr
    elif isinstance(ast, list):
        simple_list = [_simplify_ast(elem) for elem in ast]
        return simple_list
    elif isinstance(ast, tuple):
        return tuple(_simplify_ast(list(ast)))
    elif ast.__class__ in AST._node_attrs:
        # Handle node types with no special cases: simplify each child
        #  attribute in place.
        for attr_name in AST._node_attrs[ast.__class__]:
            attr = getattr(ast, attr_name)
            if isinstance(attr, list):
                for ii, elem in enumerate(attr):
                    attr[ii] = _simplify_ast(elem)
            else:
                setattr(ast, attr_name, _simplify_ast(attr))
        return ast
    else:
        return ast
Example #19
0
def _diff_ast(ast, wrt):
    """
    Return an AST that is the derivative of ast with respect the variable with
    name 'wrt'.

    For now, the strategy is to return the most general forms and let the
    simplifier take care of the special cases, so the result is unsimplified.
    Node types not handled below (e.g. comparisons or boolean operations)
    fall through and implicitly yield None.
    """
    if isinstance(ast, Name):
        if ast.id == wrt:
            return _ONE
        else:
            return _ZERO
    elif isinstance(ast, Constant):
        return _ZERO
    elif isinstance(ast, BinOp) and (isinstance(ast.op, Add)
                                     or isinstance(ast.op, Sub)):
        # (a +/- b)' = a' +/- b'. Reusing ast.op lets the same code handle
        #  both Add and Sub.
        return (BinOp(left=_diff_ast(ast.left, wrt),
                      op=ast.op,
                      right=_diff_ast(ast.right, wrt)))
    elif isinstance(ast, BinOp) and (isinstance(ast.op, Mult)
                                     or isinstance(ast.op, Div)):
        # Collect all the numerators and denominators together
        nums, denoms = [], []
        AST._collect_num_denom(ast, nums, denoms)

        # Collect the numerator terms into a single AST
        num = AST._make_product(nums)
        # Take the derivative of the numerator terms as a product
        num_d = _product_deriv(nums, wrt)
        if not denoms:
            # If there is no denominator
            return num_d

        denom = AST._make_product(denoms)
        denom_d = _product_deriv(denoms, wrt)

        # Derivative of x/y is x'/y + -x*y'/y**2
        term1 = BinOp(left=num_d, op=Div(), right=denom)
        term2 = BinOp(left=BinOp(left=UnaryOp(op=USub(), operand=num),
                                 op=Mult(),
                                 right=denom_d),
                      op=Div(),
                      right=BinOp(left=denom,
                                  op=Pow(),
                                  right=Constant(value=2)))
        return BinOp(left=term1, op=Add(), right=term2)

    elif isinstance(ast, BinOp) and isinstance(ast.op, Pow):
        # Use the derivative of the 'pow' function
        ast = Call(func=Name(id='pow', ctx=Load()), args=[ast.left, ast.right],
                   keywords=[])
        return _diff_ast(ast, wrt)

    elif isinstance(ast, Call):
        func_name = AST.ast2str(ast.func)
        args = ast.args
        args_d = [_diff_ast(arg, wrt) for arg in args]

        if (func_name, len(args)) in _KNOWN_FUNCS:
            # One form per argument: the partial derivative of the function
            #  with respect to that argument, written in terms of the
            #  placeholders arg0, arg1, ...
            form = copy.deepcopy(_KNOWN_FUNCS[(func_name, len(args))])
        else:
            # If this isn't a known function, our form is
            #  (f_0(args), f_1(args), ...)
            args_expr = [
                Name(id='arg%i' % ii, ctx=Load()) for ii in range(len(args))
            ]
            form = [
                Call(func=Name(id='%s_%i' % (func_name, ii), ctx=Load()),
                     args=args_expr,
                     keywords=[]) for ii in range(len(args))
            ]

        # Map every placeholder to its actual argument and substitute them
        #  all in a single pass, so an inserted argument that happens to
        #  mention e.g. a variable named 'arg1' is not substituted again.
        placeholders = {'arg%i' % ii: arg for ii, arg in enumerate(args)}

        # We build up the terms in our derivative
        #  f_0(x,y)*x' + f_1(x,y)*y', etc.
        outs = []
        for arg_d, arg_form_d in zip(args_d, form):
            # We skip arguments with 0 derivative
            if arg_d == _ZERO:
                continue
            # Use the return value: if the form is a bare placeholder such
            #  as 'arg0', the substitution replaces the root node itself.
            arg_form_d = Substitution._sub_subtrees_for_vars(arg_form_d,
                                                             placeholders)
            outs.append(BinOp(left=arg_form_d, op=Mult(), right=arg_d))
        # If all arguments had zero derivative
        if not outs:
            return _ZERO
        else:
            # We add up all our terms
            ret = outs[0]
            for term in outs[1:]:
                ret = BinOp(left=ret, op=Add(), right=term)
            return ret

    elif isinstance(ast, UnaryOp) and isinstance(ast.op, USub):
        return UnaryOp(op=USub(), operand=_diff_ast(ast.operand, wrt))

    elif isinstance(ast, UnaryOp) and isinstance(ast.op, UAdd):
        return UnaryOp(op=UAdd(), operand=_diff_ast(ast.operand, wrt))
Example #20
0
def _ast2TeX(ast, outer=AST._FARTHEST_OUT, name_dict=None, adjust=0):
    """
    Return a TeX version of an AST.

    outer: The AST's 'parent' node, used to determine whether or not to
        enclose the result in parentheses. The default of _FARTHEST_OUT will
        never enclose the result in parentheses.

    name_dict: Optional dictionary mapping variable names used in the
        expression to preferred TeX expressions. Defaults to no remapping.

    adjust: A numerical value to adjust the priority of this ast for
        particular cases. For example, the denominator of a '/' needs
        parentheses in more cases than does the numerator.

    and/or chains are rendered with the logical wedge/vee symbols.

    Raises ValueError for node types that have no TeX translation here.
    """
    # Avoid a shared mutable default argument.
    if name_dict is None:
        name_dict = {}

    if isinstance(ast, Name):
        # Try to get a value from the name_dict, defaulting to ast.id if
        #  ast.id isn't in name_dict
        out = name_dict.get(ast.id, ast.id)
    elif isinstance(ast, Constant):
        out = str(ast.value)
    elif isinstance(ast, BinOp) and isinstance(ast.op, Add):
        out = '%s + %s' % (_ast2TeX(ast.left, ast, name_dict),
                           _ast2TeX(ast.right, ast, name_dict))
    elif isinstance(ast, BinOp) and isinstance(ast.op, Sub):
        out = '%s - %s' % (_ast2TeX(ast.left, ast, name_dict),
                           _ast2TeX(ast.right, ast, name_dict, adjust=1))
    elif isinstance(ast, BinOp) and (isinstance(ast.op, Mult)
                                     or isinstance(ast.op, Div)):
        # We collect all terms numerator and denominator
        nums, denoms = [], []
        AST._collect_num_denom(ast, nums, denoms)
        # _EMPTY_MUL ensures that parentheses are done properly, since every
        #  element is now the child of a Mul
        lam_func = lambda arg: _ast2TeX(arg, _EMPTY_MUL, name_dict)
        nums = [lam_func(term) for term in nums]
        if denoms:
            denoms = [lam_func(term) for term in denoms]
            out = r'\frac{%s}{%s}' % (r' \cdot '.join(nums),
                                      r' \cdot '.join(denoms))
        else:
            out = r' \cdot '.join(nums)
    elif isinstance(ast, BinOp) and isinstance(ast.op, Pow):
        out = '{%s}^{%s}' % (_ast2TeX(ast.left, ast, name_dict, adjust=1),
                             _ast2TeX(ast.right, ast, name_dict))
    elif isinstance(ast, UnaryOp) and isinstance(ast.op, USub):
        out = '-%s' % _ast2TeX(ast.operand, ast, name_dict)
    elif isinstance(ast, UnaryOp) and isinstance(ast.op, UAdd):
        out = '+%s' % _ast2TeX(ast.operand, ast, name_dict)
    elif isinstance(ast, Call):
        lam_func = lambda arg: _ast2TeX(arg, name_dict=name_dict)
        name = lam_func(ast.func)
        args = [lam_func(arg) for arg in ast.args]
        if name == 'sqrt' and len(args) == 1:
            # Special case
            out = r'\sqrt{%s}' % args[0]
        else:
            out = r'\operatorname{%s}\left(%s\right)' % (name,
                                                         r',\,'.join(args))
    elif isinstance(ast, BoolOp) and isinstance(ast.op, (And, Or)):
        # Render each operand independently and join them with the logical
        #  symbol. Nested and/or groups are parenthesized explicitly so that
        #  mixed chains stay unambiguous.
        joiner = r' \wedge ' if isinstance(ast.op, And) else r' \vee '
        operands = []
        for value in ast.values:
            value_tex = _ast2TeX(value, name_dict=name_dict)
            if isinstance(value, BoolOp):
                value_tex = r'\left(%s\right)' % value_tex
            operands.append(value_tex)
        out = joiner.join(operands)
    else:
        # Previously this fell through to an UnboundLocalError on `out`.
        raise ValueError('Cannot convert %s node to TeX.'
                         % ast.__class__.__name__)

    # NOTE(review): `out` is returned bare when _need_parens(...) is true;
    #  confirm that matches that helper's return convention.
    if AST._need_parens(outer, ast, adjust):
        return out
    else:
        return r'\left(%s\right)' % out
Example #21
0
def get_count_from_ast(ast_list, term):
    """
    Return how many ASTs in ast_list have the same string form as term.

    Items are compared via AST.ast2str, so structurally identical sub-trees
    count as equal even when they are distinct objects.
    """
    if not ast_list:
        return 0
    # Render term once rather than once per list item.
    target = AST.ast2str(term)
    return sum(1 for item in ast_list if AST.ast2str(item) == target)
Example #22
0
def _sub_subtrees_for_comps(ast, ast_mappings):
    """
    Replace comparison sub-trees of ast according to ast_mappings.

    ast_mappings maps the string form of a Compare node to the AST that
    should replace it. Returns the resulting tree; when ast itself is a
    mapped comparison the replacement is returned, so callers must use the
    return value.
    """
    if isinstance(ast, Compare):
        comp_str = ast2str(ast)
        if comp_str in ast_mappings:
            return ast_mappings[comp_str]
    return AST.recurse_down_tree(ast, _sub_subtrees_for_comps,
                                 (ast_mappings,))