def test_afterSubMult(self):
    """Check flattening of sums after SubToMult pre-processing.

    Each subtraction ``x - y`` is expected to come out as ``x + (-1) * y``,
    so the top-level sum flattens into one n-ary ``Add`` BoolOp.
    """
    def negated(operand):
        # Expected shape for ``- operand``: BinOp(-1, Mult, operand).
        return ast.BinOp(ast.Num(-1), ast.Mult(), operand)

    cases = [
        ("1 + 2 - 3",
         ast.BoolOp(ast.Add(),
                    [ast.Num(1), ast.Num(2), negated(ast.Num(3))])),
        ("1 + 2 - 3 + 4",
         ast.BoolOp(ast.Add(),
                    [ast.Num(1), ast.Num(2), negated(ast.Num(3)),
                     ast.Num(4)])),
        ("(1 + 2) - (3 + 4)",
         ast.BoolOp(ast.Add(),
                    [ast.Num(1), ast.Num(2),
                     negated(ast.BinOp(ast.Num(3), ast.Add(),
                                       ast.Num(4)))])),
    ]
    for source, expected in cases:
        tree = ast.parse(source, mode="eval").body
        tree = pre_processing.all_preprocessings(tree)
        flat = Flattening(ast.Add).visit(tree)
        self.assertTrue(Comparator().visit(flat, expected))
def test_differentops(self):
    """Check operator levelling on expressions that mix operator types.

    Chains of three same-operator operands are expected as a single BoolOp;
    two-operand expressions stay BinOps; ``- y`` is expected as
    ``+ (-1) * y``.
    """
    n = ast.Num
    cases = (
        ("(3 & 5 & 6)",
         ast.BoolOp(ast.BitAnd(), [n(3), n(5), n(6)])),
        ("(1 ^ 2 ^ 3) - 4",
         ast.BinOp(ast.BoolOp(ast.BitXor(), [n(1), n(2), n(3)]),
                   ast.Add(),
                   ast.BinOp(n(-1), ast.Mult(), n(4)))),
        ("((1 + 2 + 3) & (4 + 5))",
         ast.BinOp(ast.BoolOp(ast.Add(), [n(1), n(2), n(3)]),
                   ast.BitAnd(),
                   ast.BinOp(n(4), ast.Add(), n(5)))),
        ("(1 & 2 & 3) - (4 & 5)",
         ast.BinOp(ast.BoolOp(ast.BitAnd(), [n(1), n(2), n(3)]),
                   ast.Add(),
                   ast.BinOp(n(-1), ast.Mult(),
                             ast.BinOp(n(4), ast.BitAnd(), n(5))))),
        ("(1 & 2 & 3) << (4 & 5)",
         ast.BinOp(ast.BoolOp(ast.BitAnd(), [n(1), n(2), n(3)]),
                   ast.LShift(),
                   ast.BinOp(n(4), ast.BitAnd(), n(5)))),
    )
    for source, expected in cases:
        levelled = asttools.LevelOperators().visit(
            pre_processing.all_preprocessings(
                ast.parse(source, mode="eval").body))
        self.assertTrue(asttools.Comparator().visit(levelled, expected))
def _generate_set_call_body(self, external_function_name, arg_length,
                            strart_idx, target_ids):
    """Build ``<targets> = external_function_name(<args>)`` on GLOBAL_ARRAY.

    Every operand is a ``GLOBAL_ARRAY`` element addressed as
    ``GLOBAL_ARRAY[offset + strart_idx + arg_length * i]``. ``i`` is a bare
    name the generated code must bind (presumably an enclosing loop index --
    confirm with the caller). Arguments use offsets ``0 .. arg_length - 1``;
    results are stored at the offsets listed in ``target_ids``.

    Args:
        self: Unused.
        external_function_name: Name of the function being called.
        arg_length: Number of call arguments; also the stride per ``i``.
        strart_idx: Base offset added to every slot.
        target_ids: Offsets of the slots receiving the result(s); more than
            one produces a tuple-unpacking assignment.

    Returns:
        An ``ast.Assign`` node with ``lineno=0``.

    Raises:
        IndexError: if ``target_ids`` is empty.
    """
    def slot(offset, ctx):
        # GLOBAL_ARRAY[offset + strart_idx + arg_length * i]
        index = ast.BinOp(
            left=ast.Constant(value=offset + strart_idx),
            op=ast.Add(),
            right=ast.BinOp(left=ast.Constant(value=arg_length, kind=None),
                            op=ast.Mult(),
                            right=ast.Name(id='i', ctx=ast.Load())))
        return ast.Subscript(value=GLOBAL_ARRAY,
                             slice=ast.Index(value=index),
                             ctx=ctx)

    targets = [slot(target_id, ast.Store()) for target_id in target_ids]
    # range() instead of np.arange(): numpy integer scalars are not valid
    # ast.Constant values. The context must be an ast.Load() instance,
    # not the ast.Load class.
    args = [slot(arg_id, ast.Load()) for arg_id in range(arg_length)]
    call = ast.Call(func=ast.Name(id=external_function_name, ctx=ast.Load()),
                    args=args, keywords=[])
    if len(targets) > 1:
        # A tuple assignment target needs an explicit Store context too.
        target = ast.Tuple(elts=targets, ctx=ast.Store())
    else:
        target = targets[0]
    return ast.Assign(targets=[target], value=call, lineno=0)
def operator(mod):
    """Map an operator keyword or symbol to a new ``ast`` operator instance.

    Spelled-out words (``'times'``, ``'divided by'``) and symbols (``'*'``,
    ``'//'``) are both accepted. ``'|'``/``'||'`` map to the boolean
    ``ast.Or`` and ``'&'``/``'&&'`` to ``ast.And``; ``'^'`` is bitwise XOR.
    Any unrecognised value yields ``ast.Add()``.
    """
    spellings = (
        (('or', '|', '||'), ast.Or),
        (('and', '&', '&&'), ast.And),
        (('plus', '+'), ast.Add),
        (('-', 'minus'), ast.Sub),
        (('times', '*'), ast.Mult),
        (('**',), ast.Pow),
        (('divide', 'divided', 'divided by', '/'), ast.Div),
        (('//', 'floor div'), ast.FloorDiv),
        (('%', 'mod', 'modulus', 'modulo'), ast.Mod),
        (('^', 'xor'), ast.BitXor),
        (('<<',), ast.LShift),
        (('>>',), ast.RShift),
    )
    for words, node_type in spellings:
        if mod in words:
            return node_type()
    return ast.Add()
def test_BinOp(self): for node, op in self.operators.items(): self.verify(ast.BinOp(ast.Num(2), node(), ast.Num(3)), '2{}3'.format(op)) # 1 + 2 * 3 = BinOp(2 + BinOp(2 * 3)) mult = ast.BinOp(ast.Num(2), ast.Mult(), ast.Num(3)) expr = ast.BinOp(ast.Num(1), ast.Add(), mult) self.verify(expr, '1+2*3') # (1 + 2) * 3 = BinOp(BinOp(1 + 2) * 3) add = ast.BinOp(ast.Num(1), ast.Add(), ast.Num(2)) expr = ast.BinOp(add, ast.Mult(), ast.Num(3)) self.verify(expr, '(1+2)*3') # 2 * 3 + 1 = BinOp(BinOp(2 * 3) + 1) expr = ast.BinOp(mult, ast.Add(), ast.Num(1)) self.verify(expr, '2*3+1') # 3 * (1 + 2) = BinOp(3 * BinOp(1 + 2)) expr = ast.BinOp(ast.Num(3), ast.Mult(), add) self.verify(expr, '3*(1+2)') # 3 - (1 + 2) = BinOp(3 - (BinOp1 + 2)) expr = ast.BinOp(ast.Num(3), ast.Sub(), add) self.verify(expr, '3-(1+2)') # Deal with Pow's "special" precedence compared to unary operators. self.verify(ast.BinOp(ast.Num(-1), ast.Pow(), ast.Num(2)), '(-1)**2') self.verify( ast.UnaryOp(ast.USub(), ast.BinOp(ast.Num(1), ast.Pow(), ast.Num(2))), '-1**2') self.verify( ast.BinOp(ast.Num(1), ast.Pow(), ast.UnaryOp(ast.USub(), ast.Num(2))), '1**(-2)')
def exponent(a, b):
    """Return an AST for d/dx (a ** b), with arbitrary expressions a and b.

    Uses the identity a ** b == e ** (log(a) * b), which gives

        d/dx (a ** b) = d/dx (log(a) * b) * e ** (log(a) * b)

    ``e`` is the module-level constant used as the base (presumably Euler's
    number -- confirm). ``multiply`` is assumed to return the derivative of
    the product of its two arguments (product rule) -- confirm.
    """
    # e ** (log(a) * b): a ** b rewritten so the chain rule applies directly.
    power = ast.BinOp(
        left=ast.Num(n=e),
        right=ast.BinOp(
            left=ast.Call(ast.Name(id='log', ctx=ast.Load()), [a], []),
            right=b,
            op=ast.Mult()),
        op=ast.Pow(),
    )
    # A separate log(a) node is built here so the two subtrees share nothing.
    d_exponent = multiply(
        ast.Call(ast.Name(id='log', ctx=ast.Load()), [a], []),
        b)
    return ast.BinOp(left=d_exponent, right=power, op=ast.Mult())
def derive(expr):
    """Return an AST for the derivative of ``expr`` with respect to ``x``.

    Supported nodes:
      * ``BinOp``: dispatched on the operator type through ``operations``,
        which receives the left and right operands.
      * ``Name``: ``x`` differentiates to 1, any other name to 0.
      * ``Num``: delegated to ``integer``.
      * ``Call`` of a plain name: looked up in ``names``. The entry may be a
        callable building f'(args), a function-name string, or a ready-made
        AST node. The result is multiplied by the derivative of the first
        argument (chain rule).

    Raises:
        RuntimeError: if an operator or called function has no registered
            derivative, the call target is not a plain name, or ``expr`` is
            a node type this function cannot differentiate.
    """
    if isinstance(expr, ast.BinOp):
        try:
            rule = operations[type(expr.op)]
        except KeyError:
            raise RuntimeError(
                f"{type(expr.op).__name__} does not have a defined "
                f"derivative!") from None
        return rule(expr.left, expr.right)
    if isinstance(expr, ast.Name):
        return ast.Num(n=1) if expr.id == "x" else ast.Num(n=0)
    if isinstance(expr, ast.Num):
        return integer(expr)
    if isinstance(expr, ast.Call):
        if not isinstance(expr.func, ast.Name):
            raise RuntimeError(
                f"cannot differentiate a call to {ast.dump(expr.func)}")
        fdrv = names.get(expr.func.id)
        if callable(fdrv):
            outer = fdrv(*expr.args)
        elif isinstance(fdrv, str):
            outer = ast.Call(ast.Name(id=fdrv, ctx=ast.Load()), expr.args, [])
        elif isinstance(fdrv, ast.AST):
            outer = fdrv
        else:
            raise RuntimeError(
                f"{expr.func.id} does not have a defined derivative!")
        # Chain rule: f'(g(x)) * g'(x), with g the first argument.
        return ast.BinOp(left=outer, right=derive(expr.args[0]),
                         op=ast.Mult())
    # Previously fell through and returned None, producing a broken AST.
    raise RuntimeError(f"cannot differentiate {type(expr).__name__} nodes")
def visit_BinOp(self, node):
    """Rewrite ``a - b`` as ``a + (-1) * b``, folding literal coefficients.

    When ``b`` is ``c * y`` or ``y * c`` with a numeric literal ``c``, that
    literal is negated in place instead of adding a ``-1`` factor. If
    ``self.nbits`` is truthy, the negated coefficient is reduced modulo
    ``2 ** self.nbits``, and the factor is dropped when the result is 1.
    """
    self.generic_visit(node)
    if not isinstance(node.op, ast.Sub):
        return node

    node.op = ast.Add()
    rhs = node.right
    coeff = operand = None
    if isinstance(rhs, ast.BinOp) and isinstance(rhs.op, ast.Mult):
        if isinstance(rhs.left, ast.Num):
            coeff, operand = rhs.left, rhs.right
        elif isinstance(rhs.right, ast.Num):
            coeff, operand = rhs.right, rhs.left

    if coeff is None:
        # No literal coefficient to fold: multiply the whole rhs by -1.
        node.right = ast.BinOp(ast.Num(-1), ast.Mult(), rhs)
        return node

    if not self.nbits:
        coeff.n = -coeff.n
        return node

    modulus = 2 ** self.nbits
    if (-coeff.n) % modulus == 1:
        # Coefficient becomes 1: keep only the other factor.
        node.right = operand
    else:
        coeff.n = (-coeff.n) % modulus
    return node
def compile_multiply(p):
    """Build a left-associated product AST from a parsed multiply form.

    ``p[0]`` is not part of the product (presumably the operator token);
    ``p[1:]`` are the factors, each converted with ``build_ast``.
    ``[op, a, b, c]`` yields ``(a * b) * c``; a single factor is returned
    on its own. An iterative fold avoids deep recursion on long products.

    Raises:
        ValueError: if ``p`` contains no factors.
    """
    if len(p) < 2:
        # The old recursive version recursed forever here.
        raise ValueError("compile_multiply needs at least one factor")
    result = build_ast(p[1])
    for factor in p[2:]:
        result = ast.BinOp(result, ast.Mult(), build_ast(factor))
    return result
def visit_BinOp(self, node):
    """Randomly replace this BinOp's operator, mutating ``node`` in place.

    With ``random.randint(1, 10) >= 4`` (7 in 10) the "negation" table is
    used, otherwise the "mix up" table. Operators missing from the chosen
    table are kept. Child nodes are not visited. Progress is printed to
    stdout.
    """
    # operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift
    #          | RShift | BitOr | BitXor | BitAnd | FloorDiv
    print(" in MyTransformer.visit_BinOp()")
    print(" curr op =", node.op)

    # (operator class, replacement class), tried in order with isinstance().
    negations = (
        (ast.Add, ast.Sub), (ast.Sub, ast.Add),
        (ast.Mult, ast.Div), (ast.Div, ast.FloorDiv),
        (ast.FloorDiv, ast.Div),
        (ast.LShift, ast.RShift), (ast.RShift, ast.LShift),
        (ast.BitOr, ast.BitAnd), (ast.BitAnd, ast.BitXor),
        (ast.BitXor, ast.BitOr),
        (ast.Pow, ast.Mult), (ast.Mod, ast.Div), (ast.MatMult, ast.Mult),
    )
    mix_ups = (
        (ast.Add, ast.Mult), (ast.Sub, ast.Div),
        (ast.Mult, ast.Pow), (ast.Div, ast.FloorDiv),
        (ast.FloorDiv, ast.Div),
        (ast.BitOr, ast.BitXor), (ast.BitAnd, ast.BitOr),
        (ast.BitXor, ast.BitOr),
        (ast.Pow, ast.Mult), (ast.Mod, ast.FloorDiv),
    )

    if random.randint(1, 10) >= 4:
        print(" negating...")
        table = negations
    else:
        print(" mixing up...")
        table = mix_ups

    bin_negate = node.op
    for source_op, replacement_op in table:
        if isinstance(node.op, source_op):
            bin_negate = replacement_op()
            break
    else:
        print(" did not find negation for", node.op)
    print(" bin_negate =", bin_negate)

    # The node is updated in place. copy_location() onto itself changes
    # nothing; fix_missing_locations() fills any gaps below the node.
    new_node = node
    new_node.op = bin_negate
    ast.copy_location(new_node, node)
    ast.fix_missing_locations(new_node)
    return new_node
def divide(a, b):
    """Return an AST for d/dx (a / b) using the quotient rule.

        d/dx (a / b) = (a' * b - a * b') / b ** 2

    ``a`` and ``b`` are the numerator and denominator ASTs; their derivatives
    come from ``derive``.
    """
    db = derive(b)
    da = derive(a)
    # Numerator is a'b - ab' (the old code had the terms swapped).
    numerator = ast.BinOp(
        left=ast.BinOp(left=da, right=b, op=ast.Mult()),
        right=ast.BinOp(left=a, right=db, op=ast.Mult()),
        op=ast.Sub())
    # Denominator is b squared, not b' squared.
    denominator = ast.BinOp(left=b, right=ast.Num(n=2), op=ast.Pow())
    return ast.BinOp(left=numerator, right=denominator, op=ast.Div())
def mutation_visit(self, node):
    """Swap ``+``/``-`` and ``*``/``/`` on ``node.op`` in place.

    Nodes with any other operator (mod, bitwise, ...) are returned untouched.
    """
    swaps = {
        ast.Add: ast.Sub,
        ast.Sub: ast.Add,
        ast.Mult: ast.Div,
        ast.Div: ast.Mult,
    }
    swapped_type = swaps.get(type(node.op))
    if swapped_type is not None:
        node.op = swapped_type()
    return node
def generateFermiSplitFunction(funcName, tanhModuleName="math"):
    """Yield AST nodes that define a tanh-based smooth split function.

    Yields, in order:
      1. ``from <tanhModuleName> import <tanhName>``
      2. ``<globalInvTemp> = astNum(n=100)``
      3. a FunctionDef equivalent to::

           def <funcName>(val, margin, les, gret, <inverse temp arg>):
               return les + 0.5 * (<tanhCall>((val - margin) * <globalInvTemp>) + 1) * (gret - les)

    All parameters and the return value are annotated with ``floatType``.
    NOTE(review): the inverse-temperature parameter is declared but the body
    reads ``globalInvTemp`` instead -- confirm this is intended.
    """
    value_name = "val"
    margin_name = "margin"
    low_name = "les"
    high_name = "gret"
    # Shared node: appears twice in the returned expression.
    low = ast.Name(id=low_name)

    yield ast.ImportFrom(module=tanhModuleName,
                         names=[ast.alias(name=tanhName, asname=None)],
                         level=0)
    yield ast.Assign(targets=[globalInvTemp], value=astNum(n=100))

    params = [
        ast.arg(arg=param_name, annotation=floatType)
        for param_name in (value_name, margin_name, low_name, high_name,
                           inverseTemperatureArgName)
    ]
    signature = ast.arguments(args=params, vararg=None, kwonlyargs=[],
                              kw_defaults=[], kwarg=None, defaults=[])

    half = astNum(n=0.5)
    # (val - margin) * inverse temperature
    scaled_offset = ast.BinOp(
        left=ast.BinOp(left=ast.Name(id=value_name), op=ast.Sub(),
                       right=ast.Name(id=margin_name)),
        op=ast.Mult(),
        right=globalInvTemp)
    # tanh(...) + 1
    shifted_tanh = ast.BinOp(left=tanhCall(scaled_offset), op=ast.Add(),
                             right=astNum(n=1))
    weight = ast.BinOp(left=half, op=ast.Mult(), right=shifted_tanh)
    spread = ast.BinOp(left=ast.Name(id=high_name), op=ast.Sub(), right=low)
    blended = ast.BinOp(
        left=low,
        op=ast.Add(),
        right=ast.BinOp(left=weight, op=ast.Mult(), right=spread))

    yield ast.FunctionDef(
        name=funcName,
        args=signature,
        body=[ast.Return(value=blended)],
        decorator_list=[],
        returns=floatType)
def test_noflattening(self): 'Tests where nothing should be flattened' corresp = [(["a + b", "b + a"], ast.BinOp(ast.Name('a', ast.Load()), ast.Add(), ast.Name('b', ast.Load()))), (["c*d", "d*c"], ast.BinOp(ast.Name('c', ast.Load()), ast.Mult(), ast.Name('d', ast.Load()))), (["a + c*d", "d*c + a"], ast.BinOp( ast.Name('a', ast.Load()), ast.Add(), ast.BinOp(ast.Name('c', ast.Load()), ast.Mult(), ast.Name('d', ast.Load()))))] for refstring, result in corresp: self.generic_flattening(refstring, result)
def __init__(self, base_node):
    """Collect arithmetic-operator swap mutations for the BinOp ``base_node``.

    Swaps: ``+ <-> -``, ``* <-> /``, ``% <-> **``. The ``+ -> -``,
    ``* -> /`` and ``% -> **`` swaps are skipped when either operand is a
    string literal, because those operators then mean concatenation,
    repetition or %-formatting.
    """
    BaseMutator.__init__(self, base_node)
    self.original_bin_op = base_node.op

    def is_str_literal(node):
        # Since Python 3.8 the parser emits ast.Constant for every literal,
        # so the old ``type(node) is ast.Str`` test never matched there.
        if isinstance(node, ast.Constant):
            return isinstance(node.value, str)
        if isinstance(node, ast.JoinedStr):  # f-string
            return True
        # Pre-3.8 parsers produce ast.Str. Match it by class name so the
        # deprecated ast.Str attribute is never touched.
        return type(node).__name__ == "Str"

    # op type -> (replacement op type, skip if an operand is a string literal)
    swaps = {
        ast.Add: (ast.Sub, True),
        ast.Sub: (ast.Add, False),
        ast.Mult: (ast.Div, True),
        ast.Div: (ast.Mult, False),
        ast.Mod: (ast.Pow, True),
        ast.Pow: (ast.Mod, False),
    }
    swap = swaps.get(type(base_node.op))
    if swap is None:
        return
    replacement, string_sensitive = swap
    if string_sensitive and (is_str_literal(base_node.left)
                             or is_str_literal(base_node.right)):
        return
    self.mutations.append({"op": replacement()})
def get_bin_op(self, s):
    """Return a new ``ast`` operator node for the symbol ``s``, or None."""
    symbols = (
        ('+', ast.Add), ('-', ast.Sub), ('*', ast.Mult), ('/', ast.Div),
        ('%', ast.Mod), ('<<', ast.LShift), ('>>', ast.RShift),
        ('&', ast.BitAnd), ('^', ast.BitXor), ('|', ast.BitOr),
    )
    for symbol, op_class in symbols:
        if s == symbol:
            return op_class()
    return None
def _mutate_index_dim(self, gid, ctype, node, axis=0):
    """Return ``node * cly_<gid>_info.s<hex(axis + 4)>`` as a ``cast.CBinOp``.

    The multiplier is a member of the ``cly_<gid>_info`` array-info value,
    named ``s`` followed by the hex digits of ``axis + 4`` (presumably a
    per-axis stride stored in a vector type -- confirm). The product is typed
    with ``node.ctype``.
    """
    info_var = cast.CName('cly_%s_info' % gid, ast.Load(), ctype.array_info)
    component = 's' + hex(axis + 4)[2:]
    multiplier = cast.CAttribute(info_var, component, ast.Load(),
                                 derefrence(ctype.array_info))
    # FIXME: cast type (the result just reuses node.ctype)
    return cast.CBinOp(node, ast.Mult(), multiplier, node.ctype)
def test_onBoolOp(self):
    """Check that Comparator ignores operand order inside n-ary BoolOps."""
    expr_a = ast.BoolOp(ast.Add(), [ast.Num(1), ast.Num(2), ast.Num(3)])
    expr_b = ast.BoolOp(ast.Add(), [ast.Num(3), ast.Num(2), ast.Num(1)])
    self.assertTrue(asttools.Comparator().visit(expr_a, expr_b))

    # Nested, reordered BoolOps. Use an operator instance (ast.Add()), as
    # everywhere else: the bare class ast.Add is not a valid AST operator
    # and could mask a mismatch in the comparator's operator check.
    expr_a = ast.BoolOp(ast.Add(),
                        [ast.Num(1),
                         ast.BoolOp(ast.Mult(), [ast.Num(5), ast.Num(6)]),
                         ast.Num(4)])
    expr_b = ast.BoolOp(ast.Add(),
                        [ast.BoolOp(ast.Mult(), [ast.Num(6), ast.Num(5)]),
                         ast.Num(4),
                         ast.Num(1)])
    self.assertTrue(asttools.Comparator().visit(expr_a, expr_b))
def gen_flat_index(idxs, shape):
    """Return an AST for the row-major flat offset of ``idxs`` in ``shape``.

    Built in Horner form:
    ``((idxs[0] * shape[1] + idxs[1]) * shape[2] + idxs[2]) ...``, so
    ``shape[0]`` is never read. A single index is returned unchanged.
    """
    flat_idx = idxs[0]
    for dim, idx in enumerate(idxs[1:], start=1):
        scaled = ast.BinOp(flat_idx, ast.Mult(), ast.Num(shape[dim]))
        flat_idx = ast.BinOp(scaled, ast.Add(), idx)
    return flat_idx
def visit_Operator(self, node: Operator, *args, **kwargs) -> C.operator:
    """Translate an ``Operator`` member into a new, same-named ``C`` node.

    Raises:
        Exception: if ``node`` equals none of the known operators.
    """
    # Tried in this order; each Operator.<name> maps to C.<name>().
    for name in ('Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow',
                 'LShift', 'RShift', 'BitOr', 'BitXor', 'BitAnd', 'FloorDiv'):
        if node == getattr(Operator, name):
            return getattr(C, name)()
    raise Exception(f'unknown Operator {node!r}')
def binop_action(s, loc, tokens):
    """Parse action folding ``operand (op operand)*`` into nested BinOps.

    ``tokens[0]`` is the first operand. Every later token is an
    ``(operator symbol, operand)`` pair, folded left-associatively. Any
    symbol not listed below is treated as ``|``. ``s`` and ``loc`` are
    ignored; every BinOp gets ``lineno=1, col_offset=0``.
    """
    symbol_table = (
        ('+', ast.Add), ('-', ast.Sub), ('*', ast.Mult), ('/', ast.Div),
        ('%', ast.Mod), ('<<', ast.LShift), ('>>', ast.RShift),
        ('&', ast.BitAnd), ('^', ast.BitXor),
    )
    result = tokens[0]
    for op_char, right in tokens[1:]:
        op_class = next(
            (cls for sym, cls in symbol_table if op_char == sym), ast.BitOr)
        result = ast.BinOp(left=result, op=op_class(), right=right,
                           lineno=1, col_offset=0)
    return result
def mathTransform(self, attribute, args):
    """Translate a ``math``/builtin call into an equivalent JavaScript AST.

    * names in ``self.directMathOperations`` -> ``Math.<attribute>(*args)``
    * ``sum(b, c)`` -> ``[b, c].reduce(JS(' function(x,y) { return x+y; }'))``
    * ``randint(a, b)`` -> ``Math.floor(Math.random() * (b - a + 1)) + a``

    Returns None for any other attribute.
    """
    def math_member(member):
        # Math.<member>
        return ast.Attribute(value=ast.Name(id='Math', ctx=ast.Load()),
                             attr=member, ctx=ast.Load())

    if attribute in self.directMathOperations:
        return ast.Call(func=math_member(attribute), args=args, keywords=[])

    if attribute not in self.subtitutableOperations:
        return None

    if attribute == 'sum':
        reducer = ast.Call(
            func=ast.Name(id='JS', ctx=ast.Load()),
            args=[ast.Str(s=' function(x,y) { return x+y; }')],
            keywords=[])
        reduce_attr = ast.Attribute(value=ast.List(elts=args, ctx=ast.Load()),
                                    attr='reduce', ctx=ast.Load())
        return ast.Call(func=reduce_attr, args=[reducer], keywords=[])

    if attribute == 'randint':
        low, high = args[0], args[1]
        span = ast.BinOp(
            left=ast.BinOp(left=high, op=ast.Sub(), right=low),
            op=ast.Add(),
            right=ast.Num(n=1))
        random_call = ast.Call(func=math_member('random'), args=[],
                               keywords=[])
        floor_call = ast.Call(
            func=math_member('floor'),
            args=[ast.BinOp(left=random_call, op=ast.Mult(), right=span)],
            keywords=[])
        return ast.BinOp(left=floor_call, op=ast.Add(), right=low)

    return None
def scaler_adder_xform(node, scaler, adder):
    """Return an AST for ``(node + adder) * scaler``.

    The addition is skipped when ``adder == 0.0`` and the multiplication
    when ``scaler == 1.0``. If both are skipped, ``node`` itself is returned.
    Only the top-level result gets ``node``'s location (``copy_location``).
    """
    result = (ast.BinOp(node, ast.Add(), ast.Num(adder))
              if adder != 0.0 else node)
    if scaler != 1.0:
        result = ast.BinOp(result, ast.Mult(), ast.Num(scaler))
    return ast.copy_location(result, node)
def visit_BinOp(self, node: ast.BinOp):
    """Swap ``*`` and ``+`` throughout a BinOp subtree.

    ``a * b`` becomes ``a + b`` and vice versa, recursively. Other binary
    operators are kept, but their operands are still visited, so ``+``/``*``
    nested under e.g. ``-`` are swapped too. New nodes keep the source
    location of the node they replace.
    """
    if isinstance(node.op, ast.Mult):
        swapped = ast.Add()
    elif isinstance(node.op, ast.Add):
        swapped = ast.Mult()
    else:
        # Visit the operands so nested +/* below other operators are swapped.
        return self.generic_visit(node)
    new_node = ast.BinOp(self.visit(node.left), swapped,
                         self.visit(node.right))
    return ast.copy_location(new_node, node)
def transform(tree, **kw):
    """Swap number and string literals, and ``*`` with ``+``, in one node.

    * numeric literal ``n`` -> string literal ``str(n)``
    * string literal ``s`` -> numeric literal ``int(s)`` (raises ValueError
      if ``s`` is not an integer string)
    * ``a * b`` -> ``a + b`` and ``a + b`` -> ``a * b``

    Returns None for any other node. ``kw`` is accepted and ignored.
    """
    # isinstance, not ``type(...) is``: since Python 3.8 literals are parsed
    # as ast.Constant, which ast.Num / ast.Str only match via isinstance.
    if isinstance(tree, ast.Num):
        return ast.Str(s=str(tree.n))
    if isinstance(tree, ast.Str):
        return ast.Num(n=int(tree.s))
    if isinstance(tree, ast.BinOp):
        if type(tree.op) is ast.Mult:
            return ast.BinOp(tree.left, ast.Add(), tree.right)
        if type(tree.op) is ast.Add:
            return ast.BinOp(tree.left, ast.Mult(), tree.right)
    return None
def visit_BinOp(self, node):
    """Rewrite ``x << k`` as ``x * 2**k`` when ``k`` is a literal int >= 0.

    Shifts by a non-literal, negative or non-integer amount, and all other
    BinOps, are only visited recursively. The new nodes copy their source
    locations from the nodes they replace.
    """
    if not isinstance(node.op, ast.LShift):
        return self.generic_visit(node)
    shift = node.right
    # ``x << -1`` raises and ``x << 1.5`` is a TypeError in Python; turning
    # them into ``x * 0.5`` / ``x * 2**1.5`` would silently change meaning.
    if (isinstance(shift, ast.Num) and isinstance(shift.n, int)
            and shift.n >= 0):
        self.generic_visit(node)
        factor = ast.copy_location(ast.Num(2 ** node.right.n), node.right)
        return ast.copy_location(
            ast.BinOp(node.left, ast.Mult(), factor), node)
    return self.generic_visit(node)
def 二元表达式(片段):
    """Build a binary-operation node from a parsed ``left op right`` fragment.

    ``片段`` (fragment) holds the left operand, the operator token (read
    with ``getstr()``) and the right operand, in that order.

    Raises:
        ValueError: if the operator is not ``+``, ``-`` or ``*``.
    """
    左 = 片段[0]  # left operand
    右 = 片段[2]  # right operand
    运算符 = 片段[1].getstr()  # operator text
    # Operator text -> Python AST operator class.
    对照表 = {'+': ast.Add, '-': ast.Sub, '*': ast.Mult}
    if 运算符 not in 对照表:
        # Unknown operator is a hard error; never drop into a debugger here.
        raise ValueError(f'unsupported binary operator: {运算符!r}')
    python运算 = 对照表[运算符]()
    return 语法树.二元运算(左, python运算, 右, 片段)
def visit_UnaryOp(self, node):
    """Rewrite ``-x`` as ``(-1) * x``, folding into a literal coefficient.

    ``-(c * y)`` becomes ``(-c) * y`` and ``-(y * c)`` becomes
    ``y * (-c)`` when ``c`` is a numeric literal. Other unary operators are
    returned unchanged after their operand has been visited.
    """
    self.generic_visit(node)
    if not isinstance(node.op, ast.USub):
        return node

    operand = node.operand
    is_product = (isinstance(operand, ast.BinOp)
                  and isinstance(operand.op, ast.Mult))
    if is_product and isinstance(operand.left, ast.Num):
        return ast.BinOp(ast.Num(-operand.left.n), ast.Mult(), operand.right)
    if is_product and isinstance(operand.right, ast.Num):
        return ast.BinOp(operand.left, ast.Mult(), ast.Num(-operand.right.n))
    return ast.BinOp(ast.Num(-1), ast.Mult(), operand)
def visit_UnaryOp(self, node):
    """Remove unary operators: ``+x -> x``, ``-x -> (-1)*x``, ``~x -> (-1)^x``.

    The operand is visited first. The ``~`` rewrite relies on
    ``~x == x ^ -1`` for Python integers.

    Raises:
        AssertionError: for any other unary operator (e.g. ``not``).
    """
    operand = self.visit(node.operand)
    if isinstance(node.op, ast.UAdd):
        return operand
    if isinstance(node.op, ast.USub):
        return ast.BinOp(ast.Num(-1), ast.Mult(), operand)
    if isinstance(node.op, ast.Invert):
        return ast.BinOp(ast.Num(-1), ast.BitXor(), operand)
    # Explicit raise rather than ``assert False``: asserts are stripped
    # under ``python -O``. The method would then return None and the
    # NodeTransformer would silently delete the node. AssertionError is kept
    # as the exception type for existing callers.
    raise AssertionError('unhandled node type: ' + ast.dump(node))
def operator(value):
    """Map an arithmetic symbol (``+ - * /``) to a new ``ast`` operator node.

    Raises:
        Exception: for any other symbol.
    """
    for symbol, node_class in (("+", ast.Add), ("-", ast.Sub),
                               ("*", ast.Mult), ("/", ast.Div)):
        if value == symbol:
            return node_class()
    raise Exception("operator not supported: {0}".format(value))