def test_BinOp(self):
    """Check how BinOp nodes are rendered: operator spelling and minimal parentheses."""
    # Every operator in self.operators applied to two plain literals, e.g. '2+3'.
    for op_node_class, symbol in self.operators.items():
        literal_pair = ast.BinOp(ast.Num(2), op_node_class(), ast.Num(3))
        self.verify(literal_pair, '2{}3'.format(symbol))

    two_times_three = ast.BinOp(ast.Num(2), ast.Mult(), ast.Num(3))
    one_plus_two = ast.BinOp(ast.Num(1), ast.Add(), ast.Num(2))
    cases = [
        # 1 + 2 * 3 = BinOp(1 + BinOp(2 * 3))
        (ast.BinOp(ast.Num(1), ast.Add(), two_times_three), '1+2*3'),
        # (1 + 2) * 3 = BinOp(BinOp(1 + 2) * 3)
        (ast.BinOp(one_plus_two, ast.Mult(), ast.Num(3)), '(1+2)*3'),
        # 2 * 3 + 1 = BinOp(BinOp(2 * 3) + 1)
        (ast.BinOp(two_times_three, ast.Add(), ast.Num(1)), '2*3+1'),
        # 3 * (1 + 2) = BinOp(3 * BinOp(1 + 2))
        (ast.BinOp(ast.Num(3), ast.Mult(), one_plus_two), '3*(1+2)'),
        # 3 - (1 + 2) = BinOp(3 - BinOp(1 + 2))
        (ast.BinOp(ast.Num(3), ast.Sub(), one_plus_two), '3-(1+2)'),
        # Pow's "special" precedence compared to unary operators.
        (ast.BinOp(ast.Num(-1), ast.Pow(), ast.Num(2)), '(-1)**2'),
        (ast.UnaryOp(ast.USub(), ast.BinOp(ast.Num(1), ast.Pow(), ast.Num(2))),
         '-1**2'),
        (ast.BinOp(ast.Num(1), ast.Pow(), ast.UnaryOp(ast.USub(), ast.Num(2))),
         '1**(-2)'),
    ]
    for tree, expected_source in cases:
        self.verify(tree, expected_source)
def visit_Num(self, node):
    """Rewrite large exact powers of ten or two as ``base ** exponent``.

    Integers >= 10**5 that are exactly ``10**k`` become ``BinOp(10 ** k)``;
    otherwise integers >= 2**17 that are exactly ``2**k`` become
    ``BinOp(2 ** k)``.  Non-int literals and all other ints are returned
    unchanged.
    """
    num = node.n
    if not isinstance(num, int):
        return node
    # Exact integer checks: math.log10/log2 work in floating point, so e.g.
    # 10**16 + 1 or 2**60 + 1 used to be misclassified as exact powers and
    # silently replaced by a different value.
    if num >= 10**5:
        power_10 = int(round(math.log10(num)))  # only a candidate; verified below
        if 10 ** power_10 == num:
            return ast.BinOp(ast.Num(10), ast.Pow(), ast.Num(power_10))
    if num >= 2**17 and (num & (num - 1)) == 0:
        # A positive int with a single set bit is a power of two.
        power_2 = num.bit_length() - 1
        return ast.BinOp(ast.Num(2), ast.Pow(), ast.Num(power_2))
    return node
def power(self) -> ast.expr:
    """Parse ``await_ ['**' power]``, i.e. an exponentiation chain.

    ``**`` is right-associative, as in Python: ``a ** b ** c`` is parsed as
    ``a ** (b ** c)``.  The location of each BinOp comes from
    ``self.get_loc(left, right)``.
    """
    left = self.await_()
    if self.match_(TokenType.STAR_STAR):
        # Recursing (instead of looping) makes the right operand absorb the
        # rest of the chain, giving right-associativity.
        # NOTE(review): Python also allows a unary operand on the right
        # (``2 ** -1``); if this parser has a unary/factor rule, it may belong
        # here instead of ``self.power()`` -- TODO confirm.
        right = self.power()
        return ast.BinOp(left, ast.Pow(), right, **self.get_loc(left, right))
    return left
def visit_Operator(self, node: Operator, *args, **kwargs) -> C.operator:
    """Convert an ``Operator`` member into a new instance of the ``C`` class with the same name.

    Candidates are compared with ``==`` in a fixed order and the first match
    wins.  Extra positional and keyword arguments are accepted and ignored.

    Raises:
        Exception: if ``node`` equals none of the supported operators.
    """
    for op_name in ('Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow',
                    'LShift', 'RShift', 'BitOr', 'BitXor', 'BitAnd', 'FloorDiv'):
        # Both attributes are looked up lazily, one candidate at a time.
        if node == getattr(Operator, op_name):
            return getattr(C, op_name)()
    raise Exception(f'unknown Operator {node!r}')
def operator(mod):
    """Map an operator word or symbol (``'plus'``, ``'+'``, ``'divided by'`` ...) to a new ``ast`` operator.

    ``'or'``/``'|'``/``'||'`` give ``ast.Or`` and ``'and'``/``'&'``/``'&&'``
    give ``ast.And``.  Anything unrecognised falls back to ``ast.Add()``.
    """
    spellings = (
        (ast.Or, ('or', '|', '||')),
        (ast.And, ('and', '&', '&&')),
        (ast.Add, ('plus', '+')),
        (ast.Sub, ('-', 'minus')),
        (ast.Mult, ('times', '*')),
        (ast.Pow, ('**',)),
        (ast.Div, ('divide', 'divided', 'divided by', '/')),
        (ast.FloorDiv, ('//', 'floor div')),
        (ast.Mod, ('%', 'mod', 'modulus', 'modulo')),
        (ast.BitXor, ('^', 'xor')),
        (ast.LShift, ('<<',)),
        (ast.RShift, ('>>',)),
    )
    for op_class, words in spellings:
        if mod in words:
            return op_class()
    return ast.Add()
def exponent(a, b):
    """Return an AST for the derivative of ``a ** b``.

    Uses ``a ** b == e ** (log(a) * b)``, so the result is built as
    ``multiply(log(a), b) * e ** (log(a) * b)``.

    NOTE(review): ``multiply`` is presumably this module's product-rule helper,
    returning the derivative of ``log(a) * b``, and ``e`` a module-level Euler
    constant -- confirm against their definitions.
    """
    # e ** (log(a) * b), i.e. a ** b rewritten with base e.
    rs = ast.BinOp(
        left=ast.Num(n=e),
        right=ast.BinOp(
            left=ast.Call(ast.Name(id='log', ctx=ast.Load()), [a], []),
            right=b,
            op=ast.Mult(),
        ),
        op=ast.Pow(),
    )
    return ast.BinOp(
        left=multiply(
            ast.Call(ast.Name(id='log', ctx=ast.Load()), [a], []),
            b,
        ),
        right=rs,
        op=ast.Mult(),
    )
def __init__(self, base_node):
    """Record the arithmetic-operator mutations available for the BinOp ``base_node``.

    The swaps are Add<->Sub, Mult<->Div and Mod<->Pow.  Each is appended to
    ``self.mutations`` as ``{"op": <new operator instance>}``.  Add, Mult and
    Mod are skipped when either operand is a string literal, because there
    they mean concatenation (``'a' + 'b'``), repetition (``'a' * 50``) and
    %-formatting (``'%s' % x``) rather than arithmetic.
    """
    BaseMutator.__init__(self, base_node)
    self.original_bin_op = base_node.op

    def is_string_literal(operand):
        # Python 3.8+ parses every literal as ast.Constant, so the old
        # ``type(x) is ast.Str`` test never matched there.  Older versions
        # produce Str/Bytes nodes, and f-strings are JoinedStr; the names are
        # compared as strings because those classes don't exist on every
        # Python version.
        constant_type = getattr(ast, 'Constant', ())
        if isinstance(operand, constant_type):
            return isinstance(operand.value, (str, bytes))
        return type(operand).__name__ in ('Str', 'Bytes', 'JoinedStr')

    swaps = {
        ast.Add: ast.Sub,
        ast.Sub: ast.Add,
        ast.Mult: ast.Div,
        ast.Div: ast.Mult,
        ast.Mod: ast.Pow,
        ast.Pow: ast.Mod,
    }
    op_type = type(base_node.op)
    replacement = swaps.get(op_type)
    if replacement is None:
        return
    string_sensitive = (ast.Add, ast.Mult, ast.Mod)
    if op_type in string_sensitive and (is_string_literal(base_node.left)
                                        or is_string_literal(base_node.right)):
        return
    self.mutations.append({"op": replacement()})
def matching_nodes(self, exprs):
    """Yield the candidate expressions whose bytecode matches the frame's last instruction.

    For each ``expr`` in ``exprs``:

    1. Put ``expr ** sentinel`` in its place in the tree (via ``get_setter``).
    2. Recompile with ``self.compile_instructions()``.
    3. Restore the original node.

    The instruction just before the one whose ``argval`` is ``sentinel``
    (skipping back over ``EXTENDED_ARG``) is taken as the instruction that
    produced ``expr``'s value.  ``expr`` is yielded when that instruction's
    offset equals ``self.frame.f_lasti``.

    A candidate is skipped when the modified tree raises ``SyntaxError`` on
    compilation or the sentinel does not appear in the new bytecode.
    """
    for i, expr in enumerate(exprs):
        setter = get_setter(expr)
        # Wrap the candidate as ``expr ** sentinel`` so its position in the
        # recompiled bytecode can be found via the sentinel constant.
        replacement = ast.BinOp(
            left=expr,
            op=ast.Pow(),
            right=ast.Str(s=sentinel),
        )
        ast.fix_missing_locations(replacement)
        setter(replacement)
        try:
            instructions = self.compile_instructions()
        except SyntaxError:
            continue
        finally:
            # Always restore the original node, including when compilation fails.
            setter(expr)
        indices = [
            i
            for i, instruction in enumerate(instructions)
            if instruction.argval == sentinel
        ]
        if not indices:
            continue
        # NOTE(review): ``only`` presumably requires exactly one match -- confirm.
        arg_index = only(indices) - 1
        # Step back over EXTENDED_ARG prefix instructions to the real instruction.
        while instructions[arg_index].opname == 'EXTENDED_ARG':
            arg_index -= 1

        if instructions[arg_index].offset == self.frame.f_lasti:
            yield expr
def visit_BinOp(self, node):
    """Randomly replace this node's binary operator with a related one.

    One ``random.randint(1, 10)`` draw chooses the table.  A roll of 4 or more
    (7 in 10) uses the "negation" table; otherwise the "mix up" table is used.
    The first entry whose source class matches ``node.op`` (via ``isinstance``)
    supplies a fresh replacement operator.  If nothing matches, the operator is
    kept and a message is printed.

    ``node`` is modified in place and returned; its children are not visited.
    Progress is traced with ``print``.
    """
    # operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift | RShift | BitOr | BitXor | BitAnd | FloorDiv
    print(" in MyTransformer.visit_BinOp()")
    print(" curr op =", node.op)

    negations = (
        (ast.Add, ast.Sub),
        (ast.Sub, ast.Add),
        (ast.Mult, ast.Div),
        (ast.Div, ast.FloorDiv),
        (ast.FloorDiv, ast.Div),
        (ast.LShift, ast.RShift),
        (ast.RShift, ast.LShift),
        (ast.BitOr, ast.BitAnd),
        (ast.BitAnd, ast.BitXor),
        (ast.BitXor, ast.BitOr),
        (ast.Pow, ast.Mult),
        (ast.Mod, ast.Div),
        (ast.MatMult, ast.Mult),
    )
    mix_ups = (
        (ast.Add, ast.Mult),
        (ast.Sub, ast.Div),
        (ast.Mult, ast.Pow),
        (ast.Div, ast.FloorDiv),
        (ast.FloorDiv, ast.Div),
        (ast.BitOr, ast.BitXor),
        (ast.BitAnd, ast.BitOr),
        (ast.BitXor, ast.BitOr),
        (ast.Pow, ast.Mult),
        (ast.Mod, ast.FloorDiv),
    )

    # Pseudorandomly decide whether to negate or just mix up.
    if random.randint(1, 10) >= 4:
        print(" negating...")
        table = negations
    else:
        print(" mixing up...")
        table = mix_ups

    replacement_op = node.op
    for old_class, new_class in table:
        if isinstance(node.op, old_class):
            replacement_op = new_class()
            break
    else:
        print(" did not find negation for", node.op)
    print(" bin_negate =", replacement_op)

    # Install the new operator: BinOp(expr left, operator op, expr right).
    node.op = replacement_op
    ast.copy_location(node, node)
    ast.fix_missing_locations(node)
    return node
def matching_nodes(self, exprs):
    """Yield the candidate expressions that compile to the instruction at ``self.lasti``.

    ``original_index`` is the position of the instruction at offset
    ``self.lasti`` in ``self.get_original_instructions()``.  For each ``expr``:

    1. Temporarily replace it by ``expr ** sentinel`` and recompile.
    2. Pop the sentinel's ``LOAD_CONST`` and the following ``BINARY_POWER``
       from the new instruction list.
    3. Require that the instruction just before the sentinel sits at
       ``original_index``.
    4. Assert that the remaining opnames line up with the original ones, apart
       from the tolerated differences listed below.

    Matching expressions are yielded.

    Uses ``assert`` statements, which are skipped under ``python -O``.
    """
    original_instructions = self.get_original_instructions()
    original_index = only(
        i
        for i, inst in enumerate(original_instructions)
        if inst.offset == self.lasti
    )
    for i, expr in enumerate(exprs):
        setter = get_setter(expr)
        # Wrap the candidate as ``expr ** sentinel`` so it can be located in the
        # recompiled bytecode.
        replacement = ast.BinOp(
            left=expr,
            op=ast.Pow(),
            right=ast.Str(s=sentinel),
        )
        ast.fix_missing_locations(replacement)
        setter(replacement)
        try:
            instructions = self.compile_instructions()
        finally:
            # Always put the original node back.
            setter(expr)
        indices = [
            i
            for i, instruction in enumerate(instructions)
            if instruction.argval == sentinel
        ]
        if not indices:
            continue
        sentinel_index = only(indices)
        new_index = sentinel_index - 1
        # Remove the two instructions our wrapper added so the lists line up.
        assert instructions.pop(sentinel_index).opname == 'LOAD_CONST'
        assert instructions.pop(sentinel_index).opname == 'BINARY_POWER'

        if new_index != original_index:
            continue

        # Wrapping ``obj.meth`` can turn LOAD_METHOD/LOOKUP_METHOD into
        # LOAD_ATTR (and CALL_METHOD into CALL_FUNCTION); undo that difference.
        call_method = False
        if (original_instructions[new_index].opname in ('LOAD_METHOD', 'LOOKUP_METHOD')
                and instructions[new_index].opname == 'LOAD_ATTR'):
            call_method = True
            instructions[new_index] = original_instructions[new_index]

        for inst1, inst2 in zip_longest(original_instructions, instructions):
            # Skip the first CALL_METHOD/CALL_FUNCTION pair after the swap above.
            if (call_method
                    and inst1.opname == 'CALL_METHOD'
                    and inst2.opname == 'CALL_FUNCTION'):
                call_method = False
                continue

            # Tolerated differences: jump variants and PRINT_EXPR vs POP_TOP.
            assert (inst1.opname == inst2.opname or
                    all('JUMP_IF_' in inst.opname for inst in [inst1, inst2]) or
                    all(inst.opname in ('JUMP_FORWARD', 'JUMP_ABSOLUTE') for inst in [inst1, inst2]) or
                    (inst1.opname == 'PRINT_EXPR' and inst2.opname == 'POP_TOP')), (
                inst1, inst2, ast.dump(expr), expr.lineno, self.code.co_filename)

        yield expr
def visit_BinOp(self, node):
    """Emit a binary operation into ``self.result``.

    Power nodes are delegated to ``self.pow_node`` and its return value is
    passed through.  Every other operator is emitted as ``'('``, left operand,
    operator, right operand, ``')'`` by visiting the parts in order; in that
    case the method returns ``None``.
    """
    # isinstance is the idiomatic type test and avoids allocating a throwaway
    # ast.Pow() on every call just to compare types.
    if isinstance(node.op, ast.Pow):
        return self.pow_node(node)
    self.result.append('(')
    self.visit(node.left)
    self.visit(node.op)
    self.visit(node.right)
    self.result.append(')')
def divide(a, b):
    """Return an AST for the derivative of the quotient ``a / b``.

    Applies the quotient rule ``(a / b)' = (a' * b - a * b') / b ** 2``.
    The previous version built ``(a * b' - b * a') / b' ** 2``, which has the
    wrong sign and squares the derivative of ``b`` instead of ``b`` itself.
    ``derive`` is this module's differentiation entry point.
    """
    db = derive(b)
    da = derive(a)
    numerator = ast.BinOp(
        left=ast.BinOp(left=da, right=b, op=ast.Mult()),
        right=ast.BinOp(left=a, right=db, op=ast.Mult()),
        op=ast.Sub(),
    )
    denominator = ast.BinOp(left=b, right=ast.Num(n=2), op=ast.Pow())
    return ast.BinOp(left=numerator, right=denominator, op=ast.Div())
def visit_BinOp(self, node):
    """Append a binary operation to ``self.string``.

    Power nodes are handed to ``self.handle_power``.  Every other operator is
    written as ``(`` left, operator, right ``)`` by visiting the parts in
    order.  Returns ``None`` in both cases.
    """
    # isinstance is the idiomatic type test and avoids building a throwaway
    # ast.Pow() on every call just to compare types.
    if isinstance(node.op, ast.Pow):
        self.handle_power(node)
    else:
        self.string += "("
        self.visit(node.left)
        self.visit(node.op)
        self.visit(node.right)
        self.string += ")"
def create_factor_node(self, factor):
    """Build the AST for ``[sign] base [** e1 [** e2 ...]]``.

    Without parentheses, exponentiation groups right to left:
    ``base ** e1 ** e2`` becomes ``base ** (e1 ** e2)``.  A sign of ``'-'`` /
    ``'minus'`` wraps the result in ``UnaryOp(USub)``, ``'+'`` / ``'plus'`` in
    ``UnaryOp(UAdd)``, and ``None`` leaves it bare.

    Raises:
        ValueError: if ``factor.sign`` is any other value.
    """
    base_node = self.to_node(factor.base)
    if len(factor.exponents) != 0:
        # Fold from the last exponent backwards to get right-associativity.
        last_exponent_index = len(factor.exponents) - 1
        right_node = self.to_node(factor.exponents[last_exponent_index])
        for i in range(last_exponent_index - 1, -1, -1):
            right_node = ast.BinOp(self.to_node(factor.exponents[i]), ast.Pow(), right_node)
        base_node = ast.BinOp(base_node, ast.Pow(), right_node)

    if factor.sign in ['-', 'minus']:
        return ast.UnaryOp(ast.USub(), base_node)
    elif factor.sign in ['+', 'plus']:
        return ast.UnaryOp(ast.UAdd(), base_node)
    elif factor.sign is None:
        return base_node
    # A bare ``raise`` here had no active exception and produced a confusing
    # "No active exception to reraise" RuntimeError.
    raise ValueError('unknown factor sign: {!r}'.format(factor.sign))
def visit_Call(self, n):
    """Handle a call node; ``sqrt(x)`` is visited as ``x ** (1 / 2)``.

    The callee name is the attribute name for ``obj.f(...)`` and the
    identifier for ``f(...)``.  For any other callee it is whatever
    ``self.visit`` returns for the callee expression.  Calls to anything other
    than ``sqrt`` yield ``[{}, {}]``.
    """
    callee = n.func
    if isinstance(callee, ast.Attribute):
        func_name = callee.attr
    elif isinstance(callee, ast.Name):
        func_name = callee.id
    else:
        func_name = self.visit(callee)

    if func_name == 'sqrt':
        # Re-express the square root as a power and visit that instead.
        as_power = ast.BinOp(left=n.args[0], op=ast.Pow(), right=ast.Num(n=1 / 2))
        return self.visit(as_power)
    return [{}, {}]
def operator_equals(mod):
    """Map an augmented-assignment token such as ``'+='`` to a new ``ast`` operator.

    ``'|='``/``'||='`` map to ``ast.Or`` and ``'&='``/``'&&='`` to ``ast.And``.
    Shift assignments are recognised as ``'<<='``/``'>>='``; the bare
    ``'<<'``/``'>>'`` spellings are still accepted for existing callers.
    Unknown tokens fall back to ``ast.Add()``.
    """
    spellings = (
        ('|=', ast.Or),
        ('||=', ast.Or),
        ('&=', ast.And),
        ('&&=', ast.And),
        ('+=', ast.Add),
        ('-=', ast.Sub),
        ('*=', ast.Mult),
        ('**=', ast.Pow),
        ('/=', ast.Div),
        ('//=', ast.FloorDiv),
        ('%=', ast.Mod),
        ('^=', ast.BitXor),
        # The augmented shift tokens were previously missing: only '<<' / '>>'
        # were matched, so '<<=' / '>>=' silently fell back to Add.
        ('<<=', ast.LShift),
        ('>>=', ast.RShift),
        ('<<', ast.LShift),
        ('>>', ast.RShift),
    )
    for token, op_class in spellings:
        if mod == token:
            return op_class()
    return ast.Add()
def power_action(s, loc, tokens):
    """Parse action for a power expression: chain trailers, then apply an optional ``**``.

    ``tokens[0]`` is the atom.  Each node in ``tokens['trailers']`` (if
    present) is linked to the expression built so far, left to right: an
    ``ast.Call`` gets it as ``func``, any other trailer (subscript or attribute
    access) as ``value``.  If ``'exponential'`` is present, the chain becomes
    ``chain ** tokens['exponential'][0]``.  ``s`` and ``loc`` are unused here
    (presumably pyparsing's parse-action signature).
    """
    expr = tokens[0]
    for trailer in tokens.get('trailers', ()):
        if isinstance(trailer, ast.Call):
            trailer.func = expr
        else:
            # Subscripts and dotted access both hang off ``.value``.
            trailer.value = expr
        expr = trailer
    if 'exponential' in tokens:
        expr = ast.BinOp(left=expr, op=ast.Pow(), right=tokens['exponential'][0])
    return expr
def visit_BinOp(self, node):
    """Report division and multiplication, each with a joke replacement, then recurse.

    ``a / b`` is reported with the replacement ``a * b``, and ``a * b`` with
    ``a ** b``.  The replacement comes from ``self.replace_node`` and is passed
    to ``self.show_error``.  Children are always visited afterwards via
    ``self.generic_visit``.
    """
    rules = (
        (ast.Div, ast.Mult, "A house divided cannot stand"),
        (ast.Mult, ast.Pow, "Go forth and multiply"),
    )
    for matched_op, suggested_op, message in rules:
        if isinstance(node.op, matched_op):
            suggestion = ast.BinOp(left=node.left, op=suggested_op(), right=node.right)
            self.show_error(
                node,
                message,
                replacement=self.replace_node(node, suggestion),
            )
            break
    self.generic_visit(node)
def BinOp(draw, expression) -> ast.BinOp:
    """Draw a random ``ast.BinOp`` with any binary operator and two operands from ``expression``.

    ``draw`` and ``expression`` are presumably a hypothesis ``draw`` callable
    and an expression strategy (``sampled_from`` / ``lists`` are hypothesis
    strategies).
    """
    # Each binary operator is listed exactly once so sampled_from picks them
    # uniformly (BitOr used to appear twice, doubling its weight).
    op = draw(
        sampled_from([
            ast.Add(), ast.Sub(), ast.Mult(), ast.Div(), ast.FloorDiv(),
            ast.Mod(), ast.Pow(), ast.LShift(), ast.RShift(), ast.BitOr(),
            ast.BitXor(), ast.BitAnd(), ast.MatMult()
        ]))
    left, right = draw(lists(expression, min_size=2, max_size=2))
    return ast.BinOp(left, op, right)
def AugAssign(draw):
    """Draw a random ``ast.AugAssign`` (``name op= expression``) over every binary operator.

    ``draw`` is presumably a hypothesis ``draw`` callable.  ``Name`` and
    ``expression`` are strategies defined elsewhere in this module.
    """
    # Each binary operator is listed exactly once so sampled_from picks them
    # uniformly (BitOr used to appear twice, doubling its weight).
    op = draw(
        sampled_from([
            ast.Add(), ast.Sub(), ast.Mult(), ast.Div(), ast.FloorDiv(),
            ast.Mod(), ast.Pow(), ast.LShift(), ast.RShift(), ast.BitOr(),
            ast.BitXor(), ast.BitAnd(), ast.MatMult()
        ]))
    return ast.AugAssign(target=draw(Name(ast.Store)), op=op, value=draw(expression()))
COMMENT = Comment() class Char(object): def __init__(self, str, lineno=0): self.value = str self.lineno = lineno op_ast_map = { '+': ast.Add(), '-': ast.Sub(), '*': ast.Mult(), '/': ast.Div(), '%': ast.Mod(), '**': ast.Pow(), '<<': ast.LShift(), '>>': ast.RShift(), '|': ast.BitOr(), '^^': ast.BitXor(), '&&': ast.BitAnd(), '//': ast.FloorDiv(), '==': ast.Eq(), '!=': ast.NotEq(), '<': ast.Lt(), '<=': ast.LtE(), '>': ast.Gt(), '>=': ast.GtE(), 'is': ast.Is(), 'is_not': ast.IsNot(), 'in': ast.In(),
def mutate_Mult_to_Pow(self, node):
    """Return an ``ast.Pow`` operator to use in place of ``*``, or give up.

    Raises:
        MutationResign: when ``self.should_mutate(node)`` is falsy.
    """
    if not self.should_mutate(node):
        raise MutationResign()
    return ast.Pow()
def sqr(node):
    """Return the AST for ``node ** 2``; the literal 2 is built by this module's ``num`` helper."""
    two = num(2)
    return ast.BinOp(node, ast.Pow(), two)
def pow(left, right):
    """Return the AST for ``left ** right`` (note: shadows the builtin ``pow`` in this module)."""
    power_node = ast.BinOp(left, ast.Pow(), right)
    return power_node
class Square(Transformation):
    """
    Replaces **2 by a call to numpy.square.

    >>> import ast
    >>> from pythran import passmanager, backend
    >>> node = ast.parse('a**2')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print pm.dump(backend.Python, node)
    import numpy
    numpy.square(a)
    >>> node = ast.parse('numpy.power(a,2)')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print pm.dump(backend.Python, node)
    import numpy
    numpy.square(a)
    """

    POW_PATTERN = ast.BinOp(AST_any(), ast.Pow(), ast.Num(2))
    POWER_PATTERN = ast.Call(ast.Attribute(ast.Name('numpy', ast.Load()),
                                           'power', ast.Load()),
                             [AST_any(), ast.Num(2)], [], None, None)

    def __init__(self):
        Transformation.__init__(self)

    def replace(self, value):
        """Return ``numpy.square(value)`` and flag that the tree changed and needs ``import numpy``."""
        self.update = self.need_import = True
        return ast.Call(ast.Attribute(ast.Name('numpy', ast.Load()),
                                      'square', ast.Load()),
                        [value], [], None, None)

    def visit_Module(self, node):
        """Transform the module, then prepend ``import numpy`` if any replacement was made."""
        self.need_import = False
        self.generic_visit(node)
        if self.need_import:
            importIt = ast.Import(names=[ast.alias(name='numpy', asname=None)])
            node.body.insert(0, importIt)
        return node

    def expand_pow(self, node, n):
        """Return an AST for ``node ** n`` (``n`` a non-negative int) using repeated squaring.

        ``node ** n`` becomes ``square(node) ** (n >> 1)``, multiplied by one
        extra ``node`` when ``n`` is odd.  That extra factor is a deep copy,
        because ``node`` is already used inside the squared term.
        """
        if n == 0:
            return ast.Num(1)
        elif n == 1:
            return node
        else:
            node_square = self.replace(node)
            node_pow = self.expand_pow(node_square, n >> 1)
            if n & 1:
                return ast.BinOp(node_pow, ast.Mult(), copy.deepcopy(node))
            else:
                return node_pow

    def visit_BinOp(self, node):
        """Rewrite ``x ** 2`` as ``numpy.square(x)`` and ``x ** k`` (int ``k > 0``) by repeated squaring."""
        self.generic_visit(node)
        if ASTMatcher(Square.POW_PATTERN).search(node):
            return self.replace(node.left)
        elif isinstance(node.op, ast.Pow) and isinstance(node.right, ast.Num):
            n = node.right.n
            # Only expand genuine int exponents.  A float such as 3.0 used to
            # pass ``int(n) == n`` and then crash on ``n >> 1`` in expand_pow
            # (expanding it would also drop the float result type).  A complex
            # exponent made ``int(n)`` itself raise.
            if isinstance(n, int) and n > 0:
                return self.expand_pow(node.left, n)
            else:
                return node
        else:
            return node

    def visit_Call(self, node):
        """Rewrite ``numpy.power(x, 2)`` as ``numpy.square(x)``."""
        self.generic_visit(node)
        if ASTMatcher(Square.POWER_PATTERN).search(node):
            return self.replace(node.args[0])
        else:
            return node
def __pow__(self, other):
    """Build the expression ``self ** other`` as a ``BinOp`` node."""
    power_op = ast.Pow()
    return BinOp(self, power_op, other)
def to_node(self):
    """Convert to an AST: ``left ** right`` when an operator is present, else just ``left``."""
    if len(self.operator):
        return ast.BinOp(self.left.to_node(), ast.Pow(), self.right.to_node())
    return self.left.to_node()
def __pow__(self, right):
    """Return ``self ** right`` by delegating to ``self.apply`` with an ``ast.Pow`` operator."""
    operator_node = ast.Pow()
    return self.apply(operator_node, right)
# Constructor argument names, in positional order, for each node type as_ast
# can rebuild.  The key is also the name of the ``ast`` class to call.
_AS_AST_FIELDS = {
    "Module": ("body",),
    "Interactive": ("body",),
    "Expression": ("body",),
    "Suite": ("body",),
    "FunctionDef": ("name", "args", "body", "decorator_list"),
    "ClassDef": ("name", "bases", "body", "decorator_list"),
    "Return": ("value",),
    "Delete": ("targets",),
    "Assign": ("targets", "value"),
    "AugAssign": ("target", "op", "value"),
    "Print": ("dest", "values", "nl"),
    "For": ("target", "iter", "body", "orelse"),
    "While": ("test", "body", "orelse"),
    "If": ("test", "body", "orelse"),
    "With": ("context_expr", "optional_vars", "body"),
    "Raise": ("type", "inst", "tback"),
    "TryExcept": ("body", "handlers", "orelse"),
    "TryFinally": ("body", "finalbody"),
    "Assert": ("test", "msg"),
    "Import": ("names",),
    "ImportFrom": ("module", "names", "level"),
    "Exec": ("body", "globals", "locals"),
    "Global": ("names",),
    "Expr": ("value",),
    "BoolOp": ("op", "values"),
    "BinOp": ("left", "op", "right"),
    "UnaryOp": ("op", "operand"),
    "Lambda": ("args", "body"),
    "IfExp": ("test", "body", "orelse"),
    "Dict": ("keys", "values"),
    "Set": ("elts",),
    "ListComp": ("elt", "generators"),
    "SetComp": ("elt", "generators"),
    "DictComp": ("key", "value", "generators"),
    "GeneratorExp": ("elt", "generators"),
    "Yield": ("value",),
    "Compare": ("left", "ops", "comparators"),
    "Call": ("func", "args", "keywords", "starargs", "kwargs"),
    "Repr": ("value",),
    "Num": ("n",),
    "Attribute": ("value", "attr", "ctx"),
    "Subscript": ("value", "slice", "ctx"),
    "Name": ("id", "ctx"),
    "List": ("elts", "ctx"),
    "Tuple": ("elts", "ctx"),
    "Slice": ("lower", "upper", "step"),
    "ExtSlice": ("dims",),
    "Index": ("value",),
    "comprehension": ("target", "iter", "ifs"),
    "ExceptHandler": ("type", "name", "body"),
    "arguments": ("args", "vararg", "kwarg", "defaults"),
    "keyword": ("arg", "value"),
    "alias": ("name", "asname"),
}
# Node types whose constructors take no arguments (contexts, operators, etc.).
_AS_AST_FIELDS.update(dict.fromkeys((
    "Pass Break Continue Load Store Del AugLoad AugStore Param Ellipsis "
    "And Or Add Sub Mult Div Mod Pow LShift RShift BitOr BitXor BitAnd FloorDiv "
    "Invert Not UAdd USub Eq NotEq Lt LtE Gt GtE Is IsNot In NotIn").split(), ()))


def as_ast(dct):
    """See https://docs.python.org/2/library/ast.html

    Rebuild an ``ast`` node from a dict whose ``'ast_type'`` names the node
    class; the other keys hold the constructor arguments, passed positionally.
    ``Str`` values are encoded to ASCII bytes, dropping non-ASCII characters.
    Dicts with an unrecognised ``'ast_type'`` are returned unchanged (which
    suits use as a ``json`` ``object_hook`` -- presumably how this is called).
    The node names follow the Python 2 ``ast`` module (e.g. ``Print``,
    ``TryExcept``).

    Raises:
        KeyError: if ``'ast_type'`` or a required field is missing.
    """
    kind = dct['ast_type']
    if kind == "Str":
        # Converting to ASCII
        return ast.Str(dct["s"].encode('ascii', 'ignore'))
    try:
        field_names = _AS_AST_FIELDS[kind]
    except (KeyError, TypeError):
        return dct
    node_class = getattr(ast, kind)
    return node_class(*[dct[field] for field in field_names])
def matching_nodes(self, exprs):
    """Yield the candidate expressions that compile to the instruction at ``self.lasti``.

    ``original_index`` is the position of the instruction at offset
    ``self.lasti`` in ``self.get_original_clean_instructions()``.  For each
    ``expr``:

    1. Temporarily replace it by ``expr ** sentinel`` and recompile.
    2. Remove every sentinel ``LOAD_CONST`` + ``BINARY_POWER`` pair (there can
       be several when bytecode is duplicated).
    3. For each sentinel occurrence whose preceding instruction lands on
       ``original_index``, check via ``assert_`` that the remaining opnames
       line up with the original ones (modulo the tolerated differences below).
    4. Yield ``expr``.

    ``expr`` can therefore be yielded more than once if several sentinel
    occurrences match.
    """
    original_instructions = self.get_original_clean_instructions()
    original_index = only(
        i
        for i, inst in enumerate(original_instructions)
        if inst.offset == self.lasti
    )
    for i, expr in enumerate(exprs):
        setter = get_setter(expr)
        # noinspection PyArgumentList
        replacement = ast.BinOp(
            left=expr,
            op=ast.Pow(),
            right=ast.Str(s=sentinel),
        )
        ast.fix_missing_locations(replacement)
        setter(replacement)
        try:
            instructions = self.compile_instructions()
        finally:
            # Always put the original node back.
            setter(expr)

        indices = [
            i
            for i, instruction in enumerate(instructions)
            if instruction.argval == sentinel
        ]

        # There can be several indices when the bytecode is duplicated,
        # as happens in a finally block in 3.9+
        # First we remove the opcodes caused by our modifications
        for index_num, sentinel_index in enumerate(indices):
            # Adjustment for removing sentinel instructions below
            # in past iterations
            sentinel_index -= index_num * 2

            assert_(instructions.pop(sentinel_index).opname == 'LOAD_CONST')
            assert_(instructions.pop(sentinel_index).opname == 'BINARY_POWER')

        # Then we see if any of the instruction indices match
        for index_num, sentinel_index in enumerate(indices):
            sentinel_index -= index_num * 2
            # The instruction just before the sentinel produced expr's value.
            new_index = sentinel_index - 1

            if new_index != original_index:
                continue

            original_inst = original_instructions[original_index]
            new_inst = instructions[new_index]

            # In Python 3.9+, changing 'not x in y' to 'not sentinel_transformation(x in y)'
            # changes a CONTAINS_OP(invert=1) to CONTAINS_OP(invert=0),<sentinel stuff>,UNARY_NOT
            if (original_inst.opname == new_inst.opname in ('CONTAINS_OP', 'IS_OP')
                    and original_inst.arg != new_inst.arg
                    and (original_instructions[original_index + 1].opname
                         != instructions[new_index + 1].opname == 'UNARY_NOT')):
                # Remove the difference for the upcoming assert
                instructions.pop(new_index + 1)

            # Check that the modified instructions don't have anything unexpected.
            # Tolerated: jump variants, PRINT_EXPR vs POP_TOP, and method-call
            # opcodes vs their attribute/function-call equivalents.
            for inst1, inst2 in zip_longest(original_instructions, instructions):
                assert_(
                    inst1.opname == inst2.opname or
                    all('JUMP_IF_' in inst.opname for inst in [inst1, inst2]) or
                    all(inst.opname in ('JUMP_FORWARD', 'JUMP_ABSOLUTE') for inst in [inst1, inst2]) or
                    (inst1.opname == 'PRINT_EXPR' and inst2.opname == 'POP_TOP') or
                    (inst1.opname in ('LOAD_METHOD', 'LOOKUP_METHOD') and inst2.opname == 'LOAD_ATTR') or
                    (inst1.opname == 'CALL_METHOD' and inst2.opname == 'CALL_FUNCTION'),
                    (inst1, inst2, ast.dump(expr), expr.lineno, self.code.co_filename)
                )

            yield expr