def simplify_expr(expr):
    """
    Return a simplified version of the expression.

    The expression string is parsed with AST.strip_parse, simplified by
    _simplify_ast, and rendered back to source text with AST.ast2str.
    """
    simplified_tree = _simplify_ast(AST.strip_parse(expr))
    return AST.ast2str(simplified_tree)
def extract_comps(expr):
    """
    Extract all comparisons from the expression.

    Returns a set of source strings (rendered with AST.ast2str). Which nodes
    count as comparisons is decided by _extract_comps_ast.
    """
    comparisons = []
    _extract_comps_ast(AST.strip_parse(expr), comparisons)
    return {AST.ast2str(node) for node in comparisons}
def _extract_funcs_ast(ast, funcs_found):
    """
    Append ('name', #arg) for each function used in the ast to funcs_found.

    Call arguments are visited explicitly and the tree is then walked again
    via AST.recurse_down_tree, so the same call may be recorded more than
    once (NOTE(review): depends on whether recurse_down_tree revisits args).
    extract_funcs deduplicates by converting the list to a set.
    """
    if isinstance(ast, Call):
        callee = AST.ast2str(ast.func)
        funcs_found.append((callee, len(ast.args)))
        for argument in ast.args:
            _extract_funcs_ast(argument, funcs_found)
    return AST.recurse_down_tree(ast, _extract_funcs_ast, (funcs_found,))
def extract_vars(expr):
    """
    Return a Set of the variables used in an expression.

    Results are memoized in the module-level extract_vars_cache, keyed on the
    expression string; the cached set object itself is returned, so callers
    should treat it as read-only.
    """
    if expr not in extract_vars_cache:
        found_nodes = []
        _extract_vars_ast(AST.strip_parse(expr), found_nodes)
        extract_vars_cache[expr] = {AST.ast2str(node) for node in found_nodes}
    return extract_vars_cache[expr]
def _sub_for_func_ast(ast, func_name, func_vars, func_expr_ast):
    """
    Return an ast with the function func_name substituted out.

    Every call to func_name is replaced by a deep copy of func_expr_ast:

    * If func_vars == '*', the copy's ``values`` attribute is replaced by the
      (recursively substituted) call arguments. NOTE(review): presumably the
      template is a node with a ``values`` list, such as a BoolOp; confirm.
    * Otherwise, for calls whose argument count equals len(func_vars), each
      variable name in func_vars is replaced by the matching (recursively
      substituted) call argument.

    Calls with a different number of arguments are left in place and only
    their children are processed.
    """
    if isinstance(ast, Call):
        callee = ast2str(ast.func)
        if callee == func_name and func_vars == '*':
            working_ast = copy.deepcopy(func_expr_ast)
            new_args = [_sub_for_func_ast(arg_ast, func_name, func_vars,
                                          func_expr_ast)
                        for arg_ast in ast.args]
            # This subs out the arguments of the original function.
            working_ast.values = new_args
            return working_ast
        if callee == func_name and len(ast.args) == len(func_vars):
            # If our ast is the function we're looking for, we take the ast
            # for the function expression, substitute for its arguments, and
            # return it.
            working_ast = copy.deepcopy(func_expr_ast)
            mapping = {}
            for var_name, arg_ast in zip(func_vars, ast.args):
                mapping[var_name] = _sub_for_func_ast(arg_ast, func_name,
                                                      func_vars, func_expr_ast)
            # Use the return value: when the function body is just one of its
            # parameters (e.g. f(x) = x), the replacement happens at the root
            # and is only visible through the returned node.
            return _sub_subtrees_for_vars(working_ast, mapping)
    return AST.recurse_down_tree(ast, _sub_for_func_ast,
                                 (func_name, func_vars, func_expr_ast,))
def extract_funcs(expr):
    """
    Return a Set of the functions used in an expression.

    Each element of the Set is a ('function_name', #arguments) tuple, as
    collected by _extract_funcs_ast; duplicates are collapsed by the set.
    """
    calls = []
    parsed = AST.strip_parse(expr)
    _extract_funcs_ast(parsed, calls)
    return set(calls)
def diff_expr(expr, wrt):
    """
    Return the derivative of the expression with respect to a given variable.

    expr: expression source string.
    wrt: name of the variable to differentiate with respect to.

    The result is simplified with Simplify._simplify_ast, rendered to a
    string, and memoized in the module-level __deriv_saved dict under the key
    '<expr>__derivWRT__<wrt>'.
    """
    # Lazy %-style arguments: the message is only formatted when DEBUG
    # logging is enabled.
    logger.debug('Taking derivative of %s wrt %s', expr, wrt)
    key = '%s__derivWRT__%s' % (expr, wrt)
    if key in __deriv_saved:
        deriv = __deriv_saved[key]
        logger.debug('Found saved result %s.', deriv)
        return deriv

    ast = AST.strip_parse(expr)
    deriv = _diff_ast(ast, wrt)
    deriv = Simplify._simplify_ast(deriv)
    deriv = AST.ast2str(deriv)
    __deriv_saved[key] = deriv
    logger.debug('Computed result %s.', deriv)
    return deriv
def _extract_vars_ast(ast, vars_found):
    """
    Appends the asts of the variables used in ast to vars_found.

    A variable is any Name node other than the literals 'True' and 'False'.
    Children are visited through AST.recurse_down_tree, whose result is
    returned.
    """
    is_variable = isinstance(ast, Name) and ast.id not in ('True', 'False')
    if is_variable:
        vars_found.append(ast)
    return AST.recurse_down_tree(ast, _extract_vars_ast, (vars_found,))
def expr2TeX(expr, name_dict=None):
    """
    Return a TeX version of a python math expression.

    name_dict: A dictionary mapping variable names used in the expression to
               preferred TeX expressions. Defaults to no renaming. (None is
               used instead of a shared mutable ``{}`` default.)
    """
    if name_dict is None:
        name_dict = {}
    ast = AST.strip_parse(expr)
    return _ast2TeX(ast, name_dict=name_dict)
def _sub_subtrees_for_vars(ast, ast_mappings):
    """
    For each out_name, in_ast pair in ast_mappings, substitute in_ast for all
    occurrences of the variable named out_name in ast.

    A matching Name node is replaced by its mapped subtree, and that subtree
    is not searched further, so all substitutions happen simultaneously. The
    (possibly replaced) root is returned; callers should use the return value.
    """
    if isinstance(ast, Name):
        rendered = ast2str(ast)
        if rendered in ast_mappings:
            return ast_mappings[rendered]
    return AST.recurse_down_tree(ast, _sub_subtrees_for_vars, (ast_mappings,))
def dict2TeX(d, name_dict, lhs_form='%s', split_terms=False, simpleTeX=False):
    """
    Return TeX source rendering the equations lhs = rhs held in dict d.

    d: mapping of lhs expression strings to rhs expression strings.
    name_dict: passed to expr2TeX/_ast2TeX to rename variables in TeX.
    lhs_form: format string applied to the TeX of each lhs. If it cannot be
              %-formatted with exactly one argument (e.g. it has no '%s'), it
              is used verbatim as the lhs.
    split_terms: if True, each additive term of the rhs (as collected by
                 AST._collect_pos_neg) is placed on its own row.
    simpleTeX: if False, every entry ends with extra vertical space, \\frac is
               replaced by \\tabfrac, and the rows are wrapped in a longtable
               preceded by a \\tabfrac definition. If True, only the rows are
               returned, joined with os.linesep.
    """
    def format_lhs(lhs):
        # Only the %-formatting is guarded, so errors raised while converting
        # the expression itself are not masked as a missing placeholder.
        lhs_tex = expr2TeX(lhs, name_dict=name_dict)
        try:
            return lhs_form % lhs_tex
        except TypeError:
            return lhs_form

    lines = []
    for lhs, rhs in list(d.items()):
        lhs_tex = format_lhs(lhs)
        if split_terms:
            ast = AST.strip_parse(rhs)
            pos, neg = [], []
            AST._collect_pos_neg(ast, pos, neg)
            if pos:
                first_tex = _ast2TeX(pos[0], name_dict=name_dict)
                more_pos, more_neg = pos[1:], neg
            elif neg:
                # Purely negative rhs (e.g. '-y'): lead with the first
                # negative term instead of indexing an empty pos list.
                first_tex = r'- \, %s' % _ast2TeX(neg[0], name_dict=name_dict)
                more_pos, more_neg = [], neg[1:]
            else:
                first_tex = _ast2TeX(ast, name_dict=name_dict)
                more_pos, more_neg = [], []
            lines.append(r'$ %s $ &=& $ %s $\\' % (lhs_tex, first_tex))
            for term in more_pos:
                TeXed = _ast2TeX(term, name_dict=name_dict)
                lines.append(r' & & $ + \, %s $\\' % TeXed)
            for term in more_neg:
                TeXed = _ast2TeX(term, name_dict=name_dict)
                lines.append(r' & & $ - \, %s $\\' % TeXed)
        else:
            rhs_tex = expr2TeX(rhs, name_dict=name_dict)
            lines.append(r'$ %s $ & = & $ %s $\\' % (lhs_tex, rhs_tex))
        if not simpleTeX:
            # Force a space between TeX'd entries
            lines[-1] = '%s[5mm]' % lines[-1]

    body = os.linesep.join(lines)
    if simpleTeX:
        return body

    body = body.replace(r'\frac{', r'\tabfrac{')
    # This makes the fractions look much nicer in the tabular output. See
    # http://www.texnik.de/table/table.phtml#fractions
    return os.linesep.join([
        r'\providecommand{\tabfrac}[2]{%',
        r' \setlength{\fboxrule}{0pt}%',
        r' \fbox{$\frac{#1}{#2}$}}',
        r'\begin{longtable}{lll}',
        body,
        r'\end{longtable}',
    ])
def test__collect_num_denom(self):
    """AST._collect_num_denom splits products/quotients into factor lists."""
    expectations = [
        ('x-x', (['x - x'], [])),
        ('1/2', (['1'], ['2'])),
        ('1/2*3', (['1', '3'], ['2'])),
        ('1/(2*3)', (['1'], ['2', '3'])),
        ('1/(2/3)', (['1', '3'], ['2'])),
        ('(1*2)*(3*4)', (['1', '2', '3', '4'], [])),
        ('(1*2)*(3/4)', (['1', '2', '3'], ['4'])),
        ('(1/2)*(3/4)', (['1', '3'], ['2', '4'])),
        ('(1/2)/(3/4)', (['1', '4'], ['2', '3'])),
    ]
    parsed = [(strip_parse(src), expected) for src, expected in expectations]
    for tree, (want_nums, want_denoms) in parsed:
        got_nums, got_denoms = [], []
        AST._collect_num_denom(tree, got_nums, got_denoms)
        assert set(want_nums) == {ast2str(factor) for factor in got_nums}
        assert set(want_denoms) == {ast2str(factor) for factor in got_denoms}
def test__collect_num_denom(self):
    """
    AST._collect_num_denom splits products/quotients into factor lists.

    Cases are written as source strings parsed with strip_parse instead of
    the removed Python 2 compiler-module constructors (Mul/Div/Const), and
    the builtin set replaces the removed ``sets.Set``.
    """
    cases = [('1', (['1'], [])),
             ('1/2', (['1'], ['2'])),
             ('1/2*3', (['1', '3'], ['2'])),
             ('1/(2*3)', (['1'], ['2', '3'])),
             ('1/(2/3)', (['1', '3'], ['2'])),
             ('(1*2)*(3*4)', (['1', '2', '3', '4'], [])),
             ('(1*2)*(3/4)', (['1', '2', '3'], ['4'])),
             ('(1/2)*(3/4)', (['1', '3'], ['2', '4'])),
             ('(1/2)/(3/4)', (['1', '4'], ['2', '3'])),
             ]
    for source, (nums, denoms) in cases:
        n, d = [], []
        AST._collect_num_denom(strip_parse(source), n, d)
        n = {ast2str(term) for term in n}
        d = {ast2str(term) for term in d}
        assert set(nums) == n, (source, n)
        assert set(denoms) == d, (source, d)
def test__collect_pos_neg(self):
    """
    AST._collect_pos_neg splits sums/differences into signed term lists.

    The duplicated '(1-2)-(3-4)' case is listed once, and each assertion
    reports the failing source expression.
    """
    cases = [('-y + z', (['z'], ['y'])),
             ('1-2', (['1'], ['2'])),
             ('1-2+3', (['1', '3'], ['2'])),
             ('1-(2+3)', (['1'], ['2', '3'])),
             ('1-(2-3)', (['1', '3'], ['2'])),
             ('(1-2)-(3-4)', (['1', '4'], ['2', '3'])),
             ('(1+2)+(3+4)', (['1', '2', '3', '4'], [])),
             ('(1+2)+(3-4)', (['1', '2', '3'], ['4'])),
             ('(1-2)+(3-4)', (['1', '3'], ['2', '4'])),
             ]
    for source, (poss, negs) in cases:
        p, n = [], []
        AST._collect_pos_neg(strip_parse(source), p, n)
        p = {ast2str(term) for term in p}
        n = {ast2str(term) for term in n}
        assert set(poss) == p, (source, p)
        assert set(negs) == n, (source, n)
def test__collect_pos_neg(self):
    """
    AST._collect_pos_neg splits sums/differences into signed term lists.

    Cases are written as source strings parsed with strip_parse instead of
    the removed Python 2 compiler-module constructors (Add/Sub/Const), and
    the builtin set replaces the removed ``sets.Set``. The former
    Sub(Sub(1, 2), Sub(3, 4)) case duplicated '(1-2)-(3-4)' and is dropped.
    """
    cases = [('1', (['1'], [])),
             ('1-2', (['1'], ['2'])),
             ('1-2+3', (['1', '3'], ['2'])),
             ('1-(2+3)', (['1'], ['2', '3'])),
             ('1-(2-3)', (['1', '3'], ['2'])),
             ('(1-2)-(3-4)', (['1', '4'], ['2', '3'])),
             ('(1+2)+(3+4)', (['1', '2', '3', '4'], [])),
             ('(1+2)+(3-4)', (['1', '2', '3'], ['4'])),
             ('(1-2)+(3-4)', (['1', '3'], ['2', '4'])),
             ]
    for source, (poss, negs) in cases:
        p, n = [], []
        AST._collect_pos_neg(strip_parse(source), p, n)
        p = {ast2str(term) for term in p}
        n = {ast2str(term) for term in n}
        assert set(poss) == p, (source, p)
        assert set(negs) == n, (source, n)
def _product_deriv(terms, wrt):
    """
    Return an AST expressing the derivative of the product of all the terms.

    Applies the product rule: d(t0*...*tn) = sum_i (prod_{j != i} tj) * ti'.
    The partial products are chained with Add, nesting to the right with
    later partials on the left. terms must be a non-empty list.
    """
    if len(terms) == 1:
        return _diff_ast(terms[0], wrt)

    partials = []
    for idx, factor in enumerate(terms):
        factor_d = _diff_ast(factor, wrt)
        remaining = terms[:idx] + terms[idx + 1:]
        partials.append(AST._make_product(remaining + [factor_d]))

    total = partials[0]
    for partial in partials[1:]:
        total = BinOp(left=partial, op=Add(), right=total)
    return total
def _boolop_to_c_compare(ast, c_op):
    """Rewrite a BoolOp as a Compare chaining its operands with C operator c_op."""
    nodes = AST.recurse_down_tree(ast.values, _make_c_compatible_ast)
    return Compare(nodes[0], [c_op] * (len(nodes) - 1), list(nodes[1:]))


def _make_c_compatible_ast(ast):
    """
    Return an ast whose source rendering (ast2str) is valid C.

    Rewrites applied:
      * x ** y       -> pow(x, y)
      * int constants -> float constants (mutated in place)
      * a and b      -> Compare node with '&&' as its operator
      * a or b       -> Compare node with '||' as its operator
      * not x        -> Name whose id is the literal text '!(<x>)'
    Subscript nodes (array[index]) are returned untouched so their indices
    stay integers.
    """
    if isinstance(ast, BinOp) and isinstance(ast.op, Pow):
        ast = Call(func=Name(id='pow', ctx=Load()),
                   args=[ast.left, ast.right], keywords=[])
        ast = AST.recurse_down_tree(ast, _make_c_compatible_ast)
    elif isinstance(ast, Constant) and isinstance(ast.value, int):
        ast.value = float(ast.value)
    elif isinstance(ast, Subscript):
        # These asts correspond to array[blah] and we shouldn't convert these
        # to floats, so we don't recurse down the tree in this case.
        pass
    # We need to substitute the C logical operators. Unfortunately, they
    # aren't valid python syntax, so we have to cheat a little, using Compare
    # and Name nodes abusively. This abuse may not be future-proof... sigh...
    elif isinstance(ast, BoolOp) and isinstance(ast.op, And):
        ast = _boolop_to_c_compare(ast, '&&')
    elif isinstance(ast, BoolOp) and isinstance(ast.op, Or):
        ast = _boolop_to_c_compare(ast, '||')
    elif isinstance(ast, UnaryOp) and isinstance(ast.op, Not):
        # Convert the operand itself, not just its children, so that e.g.
        # not (a and b) or not x**2 are also rendered in C syntax.
        expr = _make_c_compatible_ast(ast.operand)
        ast = Name(id='!(%s)' % ast2str(expr), ctx=Load())
    else:
        ast = AST.recurse_down_tree(ast, _make_c_compatible_ast)
    return ast
def _is_const_value(node, singleton, value):
    """True if node is singleton or a numeric Constant equal to value."""
    if node == singleton:
        return True
    return (isinstance(node, Constant)
            and isinstance(node.value, (int, float, complex))
            and node.value == value)


def _net_term_counts(plus_terms, minus_terms):
    """
    Map AST.ast2str(term) -> (representative term, net count).

    The net count is the number of occurrences in plus_terms minus those in
    minus_terms. The representative is the last occurrence in
    plus_terms + minus_terms order.
    """
    counts = {}
    for sign, terms in ((1, plus_terms), (-1, minus_terms)):
        for term in terms:
            key = AST.ast2str(term)
            count = counts[key][1] if key in counts else 0
            counts[key] = (term, count + sign)
    return counts


def _simplify_add_sub(ast):
    """Simplify a sum/difference tree rooted at an Add or Sub BinOp."""
    # We collect positive and negative terms and simplify each of them
    pos, neg = [], []
    AST._collect_pos_neg(ast, pos, neg)
    pos = [_simplify_ast(term) for term in pos]
    neg = [_simplify_ast(term) for term in neg]

    # We collect and sum the constant values
    value = sum([term.value for term in pos if isinstance(term, Constant)] +
                [-term.value for term in neg if isinstance(term, Constant)])

    # Remove the constants from our pos and neg lists
    pos = [term for term in pos if not isinstance(term, Constant)]
    neg = [term for term in neg if not isinstance(term, Constant)]

    # Negated terms switch sides: a + (-b) -> a - b and a - (-b) -> a + b.
    # Any other term, including non-negation UnaryOps such as ~x or not x,
    # stays on its own side.
    new_pos, new_neg = [], []
    for term in pos:
        if isinstance(term, UnaryOp) and isinstance(term.op, USub):
            new_neg.append(term.operand)
        else:
            new_pos.append(term)
    for term in neg:
        if isinstance(term, UnaryOp) and isinstance(term.op, USub):
            new_pos.append(term.operand)
        else:
            new_neg.append(term)
    pos, neg = new_pos, new_neg

    # Append the constant value sum to pos or neg
    if value > 0:
        pos.append(Constant(value=value))
    elif value < 0:
        neg.append(Constant(value=abs(value)))

    # Terms are identified by their source text, so each dict entry stands
    # for one unique term and its net number of occurrences.
    term_counts = _net_term_counts(pos, neg)
    ordered = pos + neg

    # We find the first term with non-zero count.
    for start, term in enumerate(ordered):
        ast_out, count = term_counts[AST.ast2str(term)]
        if count != 0:
            break
    else:
        # Every term cancelled (or there were no terms at all).
        return _ZERO
    term_counts[AST.ast2str(term)] = (ast_out, 0)
    if abs(count) != 1:
        ast_out = BinOp(left=Constant(value=abs(count)), op=Mult(),
                        right=ast_out)
    if count < 0:
        ast_out = UnaryOp(op=USub(), operand=ast_out)

    # And add in all the rest; a term's count is zeroed once it is emitted so
    # repeated occurrences are not emitted again.
    for term in ordered[start:]:
        key = AST.ast2str(term)
        rep, count = term_counts[key]
        term_counts[key] = (rep, 0)
        if count == 0:
            continue
        if abs(count) != 1:
            rep = BinOp(left=Constant(value=abs(count)), op=Mult(), right=rep)
        if count > 0:
            ast_out = BinOp(left=ast_out, op=Add(), right=rep)
        else:
            ast_out = BinOp(left=ast_out, op=Sub(), right=rep)
    return ast_out


def _simplify_mult_div(ast):
    """Simplify a product/quotient tree rooted at a Mult or Div BinOp."""
    # We collect numerator and denominator terms and simplify each of them
    num, denom = [], []
    AST._collect_num_denom(ast, num, denom)
    num = [_simplify_ast(term) for term in num]
    denom = [_simplify_ast(term) for term in denom]

    # Multiply all constant factors together (raises ZeroDivisionError for a
    # constant zero in the denominator).
    values = ([term.value for term in num if isinstance(term, Constant)] +
              [1. / term.value for term in denom if isinstance(term, Constant)])
    value = functools.reduce(operator.mul, values + [1])
    # If our value is 0, the expression is 0
    if not value:
        return _ZERO

    # Remove the constants from our num and denom lists
    num = [term for term in num if not isinstance(term, Constant)]
    denom = [term for term in denom if not isinstance(term, Constant)]

    # Here we count all the negative (USub) elements of our expression and
    # strip the USubs from their arguments. We correct for it at the end.
    num_neg = 0
    for list_of_terms in (num, denom):
        for ii, term in enumerate(list_of_terms):
            if isinstance(term, UnaryOp) and isinstance(term.op, USub):
                list_of_terms[ii] = term.operand
                num_neg += 1

    # Append the constant magnitude to the numerator; its sign joins the
    # count of negations.
    if abs(value) != 1:
        num.append(Constant(value=abs(value)))
    if value < 0:
        num_neg += 1
    make_neg = num_neg % 2

    term_counts = _net_term_counts(num, denom)

    # We walk through terms in num+denom in order, so we rearrange as little
    # as possible. Repeated factors become powers; cancelled ones vanish.
    nums, denoms = [], []
    for term in num + denom:
        key = AST.ast2str(term)
        rep, count = term_counts[key]
        # Once a term has been done, we set its count to 0, so it doesn't get
        # done again.
        term_counts[key] = (rep, 0)
        if abs(count) > 1:
            rep = BinOp(left=rep, op=Pow(), right=Constant(value=abs(count)))
        if count > 0:
            nums.append(rep)
        elif count < 0:
            denoms.append(rep)

    # We return the product of the numerator terms over the product of the
    # denominator terms
    out = AST._make_product(nums)
    if denoms:
        out = BinOp(left=out, op=Div(), right=AST._make_product(denoms))
    if make_neg:
        out = UnaryOp(op=USub(), operand=out)
    return out


def _simplify_pow(ast):
    """Simplify a Pow BinOp."""
    power = _simplify_ast(ast.right)
    base = _simplify_ast(ast.left)
    # Compare by value as well as against the shared _ZERO/_ONE nodes: AST
    # nodes compare by identity, so freshly parsed 0/1 constants would
    # otherwise never match.
    if _is_const_value(power, _ZERO, 0):
        # Anything, including 0, to the 0th power is 1, so this test should
        # come first
        return _ONE
    negative_power = (isinstance(power, Constant)
                      and isinstance(power.value, (int, float))
                      and power.value < 0)
    if ((_is_const_value(base, _ZERO, 0) and not negative_power)
            or _is_const_value(base, _ONE, 1)
            or _is_const_value(power, _ONE, 1)):
        return base
    if isinstance(base, Constant) and isinstance(power, Constant):
        return Constant(value=base.value ** power.value)
    # Getting here implies that no simplifications are possible, so just
    # return with simplified arguments
    return BinOp(left=base, op=Pow(), right=power)


def _simplify_ast(ast):
    """
    Return a simplified ast.

    Current simplifications:
        Sums (+, -): all numeric constants are combined into one constant
        placed after the other terms (1+1+x -> x + 2); repeated terms are
        merged by count (x + x -> 2 * x, x - x -> 0); negated terms switch
        sides (x + -y -> x - y).
        Products (*, /): constant factors are multiplied together (a zero
        product gives 0), repeated factors become powers (x*x -> x**2),
        matching numerator and denominator factors cancel, and unary minus
        signs are gathered into at most one leading negation.
        Powers: x**0 -> 1, 0**x -> 0 (unless x is a negative constant),
        1**x -> 1, x**1 -> x, and constant**constant is evaluated.
        Negation: --x -> x, and -(constant) is folded; +x -> x.
    Lists and tuples are simplified element-wise. Nodes whose class appears
    in AST._node_attrs have those attributes simplified in place. Anything
    else is returned unchanged.

    Raises ZeroDivisionError if a product has a constant zero denominator.
    """
    if isinstance(ast, (Name, Constant)):
        return ast
    if isinstance(ast, BinOp) and isinstance(ast.op, (Add, Sub)):
        return _simplify_add_sub(ast)
    if isinstance(ast, BinOp) and isinstance(ast.op, (Mult, Div)):
        return _simplify_mult_div(ast)
    if isinstance(ast, BinOp) and isinstance(ast.op, Pow):
        return _simplify_pow(ast)
    if isinstance(ast, UnaryOp) and isinstance(ast.op, USub):
        simple_expr = _simplify_ast(ast.operand)
        if isinstance(simple_expr, UnaryOp) and isinstance(simple_expr.op, USub):
            # Case --x
            return _simplify_ast(simple_expr.operand)
        if isinstance(simple_expr, Constant):
            if simple_expr.value == 0:
                return Constant(value=0)
            return Constant(value=-simple_expr.value)
        return UnaryOp(op=USub(), operand=simple_expr)
    if isinstance(ast, UnaryOp) and isinstance(ast.op, UAdd):
        return _simplify_ast(ast.operand)
    if isinstance(ast, list):
        return [_simplify_ast(elem) for elem in ast]
    if isinstance(ast, tuple):
        return tuple(_simplify_ast(list(ast)))
    if ast.__class__ in AST._node_attrs:
        # Handle node types with no special cases.
        for attr_name in AST._node_attrs[ast.__class__]:
            attr = getattr(ast, attr_name)
            if isinstance(attr, list):
                for ii, elem in enumerate(attr):
                    attr[ii] = _simplify_ast(elem)
            else:
                setattr(ast, attr_name, _simplify_ast(attr))
        return ast
    return ast
def _diff_ast(ast, wrt):
    """
    Return an AST that is the derivative of ast with respect the variable
    with name 'wrt'.

    The result is deliberately unsimplified; callers run it through the
    simplifier. Known functions are differentiated using the partial-derivative
    forms in _KNOWN_FUNCS (written in terms of arg0, arg1, ...). An unknown
    function f gets symbolic partials f_0(args), f_1(args), ....

    NOTE(review): node types not handled below (e.g. Compare, BoolOp,
    Subscript) fall through and return None.
    """
    # For now, the strategy is to return the most general forms, and let
    # the simplifier take care of the special cases.
    if isinstance(ast, Name):
        if ast.id == wrt:
            return _ONE
        return _ZERO
    elif isinstance(ast, Constant):
        return _ZERO
    elif isinstance(ast, BinOp) and isinstance(ast.op, (Add, Sub)):
        # (a +/- b)' = a' +/- b'; the original operator node is reused so the
        # same code serves both Add and Sub.
        return BinOp(left=_diff_ast(ast.left, wrt), op=ast.op,
                     right=_diff_ast(ast.right, wrt))
    elif isinstance(ast, BinOp) and isinstance(ast.op, (Mult, Div)):
        # Collect all the numerators and denominators together
        nums, denoms = [], []
        AST._collect_num_denom(ast, nums, denoms)

        # Collect the numerator terms into a single AST
        num = AST._make_product(nums)
        # Take the derivative of the numerator terms as a product
        num_d = _product_deriv(nums, wrt)
        if not denoms:
            # If there is no denominator
            return num_d

        denom = AST._make_product(denoms)
        denom_d = _product_deriv(denoms, wrt)

        # Derivative of x/y is x'/y + -x*y'/y**2
        term1 = BinOp(left=num_d, op=Div(), right=denom)
        term2 = BinOp(
            left=BinOp(left=UnaryOp(op=USub(), operand=num), op=Mult(),
                       right=denom_d),
            op=Div(),
            right=BinOp(left=denom, op=Pow(), right=Constant(value=2)))
        return BinOp(left=term1, op=Add(), right=term2)
    elif isinstance(ast, BinOp) and isinstance(ast.op, Pow):
        # Use the derivative of the 'pow' function
        ast = Call(func=Name(id='pow', ctx=Load()),
                   args=[ast.left, ast.right], keywords=[])
        return _diff_ast(ast, wrt)
    elif isinstance(ast, Call):
        func_name = AST.ast2str(ast.func)
        args = ast.args
        args_d = [_diff_ast(arg, wrt) for arg in args]

        if (func_name, len(args)) in _KNOWN_FUNCS:
            form = copy.deepcopy(_KNOWN_FUNCS[(func_name, len(args))])
        else:
            # If this isn't a known function, our form is
            # (f_0(args), f_1(args), ...). Each call gets its own argument
            # list so substituting into one form cannot affect another.
            form = [
                Call(func=Name(id='%s_%i' % (func_name, ii), ctx=Load()),
                     args=[Name(id='arg%i' % jj, ctx=Load())
                           for jj in range(len(args))],
                     keywords=[])
                for ii in range(len(args))
            ]

        # Substitute all placeholders at once, so an argument expression that
        # itself mentions a name like 'arg1' is not substituted again.
        mapping = {'arg%i' % ii: arg for ii, arg in enumerate(args)}

        # We build up the terms in our derivative
        # f_0(x,y)*x' + f_1(x,y)*y', etc.
        outs = []
        for arg_d, arg_form_d in zip(args_d, form):
            # We skip arguments with 0 derivative
            if arg_d == _ZERO:
                continue
            # Keep the return value: a form that is a bare placeholder (e.g.
            # 'arg1') is replaced at its root.
            arg_form_d = Substitution._sub_subtrees_for_vars(arg_form_d, mapping)
            outs.append(BinOp(left=arg_form_d, op=Mult(), right=arg_d))

        # If all arguments had zero derivative
        if not outs:
            return _ZERO
        # We add up all our terms
        ret = outs[0]
        for term in outs[1:]:
            ret = BinOp(left=ret, op=Add(), right=term)
        return ret
    elif isinstance(ast, UnaryOp) and isinstance(ast.op, USub):
        return UnaryOp(op=USub(), operand=_diff_ast(ast.operand, wrt))
    elif isinstance(ast, UnaryOp) and isinstance(ast.op, UAdd):
        return UnaryOp(op=UAdd(), operand=_diff_ast(ast.operand, wrt))
def _ast2TeX(ast, outer=AST._FARTHEST_OUT, name_dict=None, adjust=0):
    """
    Return a TeX version of an AST.

    outer: The AST's 'parent' node, used to determine whether or not to
           enclose the result in parentheses. The default of _FARTHEST_OUT
           will never enclose the result in parentheses.

    name_dict: A dictionary mapping variable names used in the expression to
               preferred TeX expressions. Defaults to no renaming.

    adjust: A numerical value to adjust the priority of this ast for
            particular cases. For example, the denominator of a '/' needs
            parentheses in more cases than does the numerator.

    Raises NotImplementedError for node types with no TeX rendering
    (e.g. Compare, Subscript, or BinOps such as % and //).
    """
    if name_dict is None:
        name_dict = {}

    if isinstance(ast, Name):
        # Try to get a value from the name_dict, defaulting to ast.id
        out = name_dict.get(ast.id, ast.id)
    elif isinstance(ast, Constant):
        out = str(ast.value)
    elif isinstance(ast, BinOp) and isinstance(ast.op, Add):
        out = '%s + %s' % (_ast2TeX(ast.left, ast, name_dict),
                           _ast2TeX(ast.right, ast, name_dict))
    elif isinstance(ast, BinOp) and isinstance(ast.op, Sub):
        out = '%s - %s' % (_ast2TeX(ast.left, ast, name_dict),
                           _ast2TeX(ast.right, ast, name_dict, adjust=1))
    elif isinstance(ast, BinOp) and isinstance(ast.op, (Mult, Div)):
        # We collect all terms numerator and denominator
        nums, denoms = [], []
        AST._collect_num_denom(ast, nums, denoms)
        # _EMPTY_MUL ensures that parentheses are done properly, since every
        # element is now the child of a Mul
        nums = [_ast2TeX(term, _EMPTY_MUL, name_dict) for term in nums]
        if denoms:
            denoms = [_ast2TeX(term, _EMPTY_MUL, name_dict) for term in denoms]
            out = r'\frac{%s}{%s}' % (r' \cdot '.join(nums),
                                      r' \cdot '.join(denoms))
        else:
            out = r' \cdot '.join(nums)
    elif isinstance(ast, BinOp) and isinstance(ast.op, Pow):
        out = '{%s}^{%s}' % (_ast2TeX(ast.left, ast, name_dict, adjust=1),
                             _ast2TeX(ast.right, ast, name_dict))
    elif isinstance(ast, UnaryOp) and isinstance(ast.op, USub):
        out = '-%s' % _ast2TeX(ast.operand, ast, name_dict)
    elif isinstance(ast, UnaryOp) and isinstance(ast.op, UAdd):
        out = '+%s' % _ast2TeX(ast.operand, ast, name_dict)
    elif isinstance(ast, Call):
        # The function name also goes through name_dict.
        name = _ast2TeX(ast.func, name_dict=name_dict)
        args = [_ast2TeX(arg, name_dict=name_dict) for arg in ast.args]
        if name == 'sqrt' and len(args) == 1:
            # Special case
            out = r'\sqrt{%s}' % args[0]
        else:
            out = r'\operatorname{%s}\left(%s\right)' % (name,
                                                        r',\,'.join(args))
    elif isinstance(ast, BoolOp) and isinstance(ast.op, (Or, And)):
        # Render the boolean expression's source text; str() of an AST node
        # would only give its type and memory address.
        out = r'\operatorname{%s}' % AST.ast2str(ast)
    else:
        raise NotImplementedError('Cannot convert %s node to TeX.'
                                  % type(ast).__name__)

    # NOTE(review): parentheses are added when _need_parens returns False,
    # which reads inverted; presumably _need_parens returns True when the node
    # can stand without parentheses -- confirm against AST._need_parens.
    if AST._need_parens(outer, ast, adjust):
        return out
    return r'\left(%s\right)' % out
def get_count_from_ast(ast_list, term):
    """
    Return how many entries of ast_list render (via AST.ast2str) to the same
    source text as term.
    """
    if not ast_list:
        return 0
    # Render the target once instead of once per list element.
    target = AST.ast2str(term)
    return sum(1 for item in ast_list if AST.ast2str(item) == target)
def _sub_subtrees_for_comps(ast, ast_mappings):
    """
    Replace Compare nodes whose source text (ast2str) is a key of
    ast_mappings with the mapped subtree; other nodes are searched
    recursively through AST.recurse_down_tree. Returns the resulting root.
    """
    if isinstance(ast, Compare):
        rendered = ast2str(ast)
        if rendered in ast_mappings:
            return ast_mappings[rendered]
    return AST.recurse_down_tree(ast, _sub_subtrees_for_comps, (ast_mappings,))