def rewrite(self):
    """Inline ``block_literal`` structures as ``^{ ... }`` block expressions.

    Every step is an ``AstMatcher().replace`` pass over ``self.ast.root``:

    * ``x = block_literal + N`` assignments are first normalised by
      ``process_block_to_be_embedded``;
    * stores into ``block_literal.<field>`` are handed to
      ``callback_find_captures`` (presumably to record the captured values);
    * ``&block_literal`` is handed to ``callback``;
    * only when ``self.did_replace`` is true afterwards (it is reset to
      ``False`` here and presumably set by ``callback``), the field stores
      are handed to ``callback2`` and the ``block_literal`` declaration to
      ``callback_remove_decl``.
    """
    def field_store_pattern():
        # Matches ``block_literal.<any field> = <any value>``.
        return ast2.Assign(
            targets=[ast2.FieldAccess(object=ast2.Name("block_literal"), field=Any())],
            value=Any(),
            decltype=None)

    self.process_block_to_be_embedded(self.ast.root)

    # Find captures.
    pattern = field_store_pattern()
    AstMatcher().replace(self.ast.root, pattern, self.callback_find_captures)

    # Replace the block literal with ^{ ... }.
    self.did_replace = False
    pattern = ast2.AddressOf(variable=ast2.Name("block_literal"))
    AstMatcher().replace(self.ast.root, pattern, self.callback)

    if not self.did_replace:
        return

    # Bind captures.
    pattern = field_store_pattern()
    AstMatcher().replace(self.ast.root, pattern, self.callback2)

    # The literal has been inlined, so its declaration is no longer needed.
    pattern = ast2.Declaration(typename=Any(), name="block_literal")
    AstMatcher().replace(self.ast.root, pattern, self.callback_remove_decl)
def walk(self, node, parent=None):
    """Recursively rewrite ``_dispatch_once`` calls guarded by an ``if``.

    ``node`` is either a list of nodes or an ``_ast.AST``. For an AST node,
    every field is walked with that node as ``parent``. For a list, each item
    is walked first. Then, if ``parent`` is an ``ast2.If`` and the item is a
    statement whose expression is ``_dispatch_once(a, b, ...)`` with at least
    two arguments:

    * the whole list is replaced in place with
      ``static dispatch_once_t onceToken;`` followed by
      ``dispatch_once(&onceToken, b);``;
    * the ``if``'s test is replaced with the constant ``1``.

    Calls with fewer than two arguments are left untouched.
    """
    if isinstance(node, list):
        for listitem in node:
            self.walk(listitem, parent)
            if not (isinstance(listitem, ast2.Statement)
                    and isinstance(parent, ast2.If)):
                continue
            call = listitem.expr
            if not (isinstance(call, ast2.Call)
                    and isinstance(call.func, ast2.Name)
                    and call.func.id == "_dispatch_once"
                    # args[1] is used below; skip malformed calls rather than
                    # crashing with IndexError.
                    and len(call.args) >= 2):
                continue
            decl = ast2.Declaration(
                typename="static dispatch_once_t", name="onceToken")
            args = [
                ast2.AddressOf(variable=ast2.Name(id="onceToken")),
                call.args[1],
            ]
            once_stmt = ast2.Statement(
                expr=ast2.Call(func=ast2.Name(id="dispatch_once"), args=args))
            node[:] = [decl, once_stmt]
            parent.test = ast2.Num(n=1)
            # The list was just replaced in place; stop iterating over it.
            break
    elif isinstance(node, _ast.AST):
        for field in node.__class__._fields:
            self.walk(getattr(node, field), node)
def process_inner_block_root(self, root):
    """Rewrite uses of ``block_literal`` inside an inlined block body.

    Runs two ``AstMatcher().replace`` passes over ``root``:
    ``block_literal + N`` goes to ``callback_capture_usage``, then
    ``block_literal.<field>`` goes to ``callback_capture_usage2``.
    """
    # block_literal + N  ->  handled by callback_capture_usage
    offset_pattern = ast2.BinOp(left=ast2.Name("block_literal"),
                                op=ast2.Add(),
                                right=ast2.Num(n=Any()))
    AstMatcher().replace(root, offset_pattern, self.callback_capture_usage)

    # block_literal.<field>  ->  handled by callback_capture_usage2
    field_pattern = ast2.FieldAccess(object=ast2.Name("block_literal"), field=Any())
    AstMatcher().replace(root, field_pattern, self.callback_capture_usage2)
def rewrite(self):
    """Collapse a decompiled fast-enumeration loop into a ``for ... in`` loop.

    ``self.walk`` is expected to populate ``self.parents`` (child -> parent
    map) and ``self.mutation_call`` (presumably the enumeration-mutation
    check call; TODO confirm in ``walk``). The rewrite only proceeds when the
    parent chain of that call is exactly
    ``If(if2) > DoWhile(while2) > DoWhile(while1) > If(if1) > Statement > call``.
    Otherwise the method returns without touching the AST.

    On success:

    * ``while1.body`` is trimmed to the statements after the loop-variable
      load;
    * the top-level ``countByEnumeratingWithState:objects:count:`` assignment
      is replaced with a ``1`` placeholder statement;
    * ``if2`` becomes ``if (1) { for (id var in collection) { ... } }``.
    """
    self.walk(self.ast.root.body, self.ast.root)
    # Climb from the mutation call through five parents; bail out on any gap.
    if self.mutation_call is None: return
    if not self.mutation_call in self.parents: return
    statement = self.parents[self.mutation_call]
    if not statement in self.parents: return
    if1 = self.parents[statement]
    if not if1 in self.parents: return
    while1 = self.parents[if1]
    if not while1 in self.parents: return
    while2 = self.parents[while1]
    if not while2 in self.parents: return
    if2 = self.parents[while2]
    # The chain must have exactly these node types, otherwise this is not the pattern.
    if not isinstance(statement, ast2.Statement): return
    if not isinstance(if1, ast2.If): return
    if not isinstance(if2, ast2.If): return
    if not isinstance(while1, ast2.DoWhile): return
    if not isinstance(while2, ast2.DoWhile): return
    # find var
    # The first `name = *ptr` assignment after if1 in the inner loop names
    # the loop variable (default `i`).
    if_in_while_idx = while1.body.index(if1)
    deref_idx = 0
    var = ast2.Name('i')
    for (idx, s) in enumerate(while1.body[if_in_while_idx + 1:]):
        if isinstance(s, ast2.Assign):
            if isinstance(s.value, ast2.Dereference) and isinstance(
                    s.targets[0], ast2.Name):
                var = s.targets[0]
                # s sits at while1.body[if_in_while_idx + 1 + idx]; +1 more
                # makes the slice below start just after it.
                deref_idx = idx + if_in_while_idx + 2
                break
    # Drop the enumeration bookkeeping. If no load was found, deref_idx is 0
    # and the whole body (including if1) is kept.
    while1.body = while1.body[deref_idx:]
    # find collection
    # The target of the first top-level countByEnumeratingWithState:... send
    # is the enumerated collection (default `collection`).
    collection = ast2.Name("collection")
    for (idx, s) in enumerate(self.ast.root.body):
        if isinstance(s, ast2.Assign):
            if isinstance(s.value, ast2.ObjCMessageSend) and isinstance(
                    s.targets[0], ast2.Name):
                if s.value.selector == "countByEnumeratingWithState:objects:count:":
                    collection = s.targets[0]
                    # Neutralise the send in place (list length is preserved).
                    self.ast.root.body[idx] = ast2.Statement(ast2.Num(n=1))
                    break
    # Replace the outer guard with an always-true if holding the for-in loop.
    if2.test = ast2.Num(n=1)
    if2.body = [
        ast2.ForEach(typename='id', variable=var, source=collection, body=while1.body)
    ]
def rewrite(self):
    """Lower ``objc_msgSend``-family calls, then rewrite class-name references.

    Each runtime entry point gets its own ``AstMatcher().replace`` pass using
    ``self.callback``, in the order listed. A final pass over every
    ``ast2.Name`` uses ``self.callback_class_rewriting``.
    """
    # TODO other msgSends
    for runtime_func in ("_objc_msgSend", "_objc_msgSendSuper", "_objc_msgSendSuper2"):
        pattern = ast2.Call(ast2.Name(runtime_func), Any())
        AstMatcher().replace(self.ast.root, pattern, self.callback)

    name_pattern = ast2.Name(Any())
    AstMatcher().replace(self.ast.root, name_pattern, self.callback_class_rewriting)
def declare_argument(self, register, t):
    """Return an ``ast2.Name`` for a function-argument register.

    The register name is mapped through ``self.parameter_names`` when an
    entry exists. The first time a name is seen it is recorded in
    ``self.locals`` with a ``None`` declaration, because the argument is
    declared implicitly by the function signature. ``t`` is accepted but
    not used here.
    """
    raw_name = register.name
    name = self.parameter_names[raw_name] if raw_name in self.parameter_names else raw_name
    if name not in self.locals:
        # Implicit declaration (function argument)
        self.locals[name] = None
    return ast2.Name(name)
def get_variable(self, register):
    """Return an ``ast2.Name`` for ``register``, declaring it on first use.

    The register name is mapped through ``self.parameter_names`` when an
    entry exists. Names not yet in ``self.locals`` get an
    ``ast2.Declaration`` whose type comes from ``self.type_for_register``.
    """
    raw_name = register.name
    name = self.parameter_names[raw_name] if raw_name in self.parameter_names else raw_name
    if name not in self.locals:
        # NOTE(review): type_for_register receives the (possibly renamed)
        # variable name, not the register object — confirm that is intended.
        self.locals[name] = ast2.Declaration(self.type_for_register(name), name)
    return ast2.Name(name)
def callback_capture_usage(self, node):
    """Turn ``block_literal + N`` into ``&<captured variable>``.

    Applies when ``N`` is a numeric offset present in ``self.captures``;
    the result is ``AddressOf(Name(self.captures[N]))``. Any other node is
    returned unchanged.
    """
    is_literal_offset = (
        isinstance(node, ast2.BinOp)
        and isinstance(node.left, ast2.Name)
        and node.left.id == "block_literal"
        and isinstance(node.right, ast2.Num)
    )
    if is_literal_offset and node.right.n in self.captures:
        captured_name = self.captures[node.right.n]
        return ast2.AddressOf(variable=ast2.Name(id=captured_name))
    return node
def process_block_to_be_embedded(self, root):
    """Normalise ``x = block_literal + N`` assignments under ``root``.

    Every plain assignment whose value is ``block_literal + <number>`` is
    handed to ``callback_process_block_to_be_embedded``.
    """
    literal_plus_offset = ast2.BinOp(left=ast2.Name("block_literal"),
                                     op=ast2.Add(),
                                     right=ast2.Num(n=Any()))
    pattern = ast2.Assign(targets=Any(), value=literal_plus_offset, decltype=None)
    AstMatcher().replace(root, pattern, self.callback_process_block_to_be_embedded)
def callback_capture_usage2(self, node):
    """Turn ``block_literal.off_<hex>`` into the captured variable's name.

    The hex suffix is parsed as the byte offset and looked up in
    ``self.captures``. The node is returned unchanged when:

    * it is not a field access on ``block_literal``;
    * the field is not of the ``off_<hex>`` form;
    * the offset is not a known capture.
    """
    prefix = "off_"
    if not isinstance(node, ast2.FieldAccess):
        return node
    if not isinstance(node.object, ast2.Name) or node.object.id != "block_literal":
        return node
    field_name = node.field
    # The matcher pattern uses field=Any(), so the field need not follow the
    # off_<hex> convention; leave such nodes alone instead of crashing.
    if not (hasattr(field_name, "startswith") and field_name.startswith(prefix)):
        return node
    try:
        offset = int(field_name[len(prefix):], 16)
    except ValueError:
        return node
    if offset not in self.captures:
        return node
    return ast2.Name(id=self.captures[offset])
def rewrite(self):
    """Walk the function body, then apply ``self.callback`` to names and assignments.

    ``self.walk`` runs first over ``self.ast.root.body``. Then two
    ``AstMatcher().replace`` passes use the same ``self.callback``: one over
    every ``ast2.Name``, then one over every plain ``ast2.Assign``.
    """
    self.walk(self.ast.root.body)
    pattern_factories = (
        lambda: ast2.Name(id=Any()),
        lambda: ast2.Assign(value=Any(), targets=Any(), decltype=None),
    )
    for make_pattern in pattern_factories:
        pattern = make_pattern()
        AstMatcher().replace(self.ast.root, pattern, self.callback)
def callback(self, node):
    """Lower an ``objc_msgSend``-family call to an ``ast2.ObjCMessageSend``.

    ``args[1]`` must be an ``ast2.ObjCSelector``; the remaining ``args[2:]``
    become the message arguments. For ``_objc_msgSend`` the receiver is
    ``args[0]``; for ``_objc_msgSendSuper``/``_objc_msgSendSuper2`` it is the
    name ``super``. Anything else is returned unchanged.
    """
    args = node.args
    if len(args) < 2 or not isinstance(args[1], ast2.ObjCSelector):
        return node
    selector = args[1].value

    callee = node.func.id
    if callee == "_objc_msgSend":
        receiver = args[0]
    elif callee in ("_objc_msgSendSuper", "_objc_msgSendSuper2"):
        receiver = ast2.Name("super")
    else:
        return node

    return ast2.ObjCMessageSend(receiver, selector, args[2:])
def callback_process_block_to_be_embedded(self, node):
    """Rewrite ``x = block_literal + N`` into ``x = &block_literal.off_<N hex>``.

    The assignment node is mutated in place and returned. The field name uses
    the same ``off_%x`` convention that ``callback_capture_usage2`` parses.
    Any node not of that shape is returned unchanged.
    """
    # The original read node.targets[0] into an unused local, which raised
    # IndexError on an empty target list before any check; it is not needed.
    if not isinstance(node, ast2.Assign):
        return node
    binop = node.value
    if not (isinstance(binop, ast2.BinOp)
            and isinstance(binop.left, ast2.Name)
            and binop.left.id == "block_literal"
            and isinstance(binop.right, ast2.Num)):
        return node
    offset = binop.right.n
    node.value = ast2.AddressOf(variable=ast2.FieldAccess(
        object=ast2.Name("block_literal"), field="off_%x" % offset))
    return node
def rewrite(self):
    """Run the post-processing pipeline over ``self.ast``.

    Steps, in order:

    1. flatten nested statement lists via ``self.flatten_lists``;
    2. drop a trailing bare ``return`` (one with no value);
    3. run ``self.callback1`` on every ``<name> <op> <name>`` binary operation;
    4. run the rewriters in the fixed order listed below. The order
       presumably matters, since each pass sees the previous passes' output.
    """
    self.flatten_lists(self.ast.root.body, self.list_removal)

    # Remove a final valueless return. Guard against an empty body, which
    # previously raised IndexError on body[-1].
    body = self.ast.root.body
    if isinstance(body, list) and body:
        last_stmt = body[-1]
        if isinstance(last_stmt, ast2.Return) and last_stmt.value is None:
            del body[-1]

    cond = ast2.BinOp(left=ast2.Name(id=Any()), op=Any(), right=ast2.Name(id=Any()))
    AstMatcher().replace(self.ast.root, cond, self.callback1)

    ForeachRewriter(self.ast).rewrite()
    AddressOfRewriter(self.ast).rewrite()
    PropagateRewriter(self.ast).rewrite()
    UnusedVarsRewriter(self.ast).rewrite()
    ExpressionEmbedRewriter(self.ast).rewrite()
    DeclarationToDefinitionRewriter(self.ast).rewrite()
    DispatchOnceRewriter(self.ast).rewrite()
    IfRemoveRewriter(self.ast).rewrite()
def convert1(self, o):
    """Convert one UCode object (instruction or operand) into an ``ast2`` node.

    Conversion by kind:

    * instructions with a destination become ``ast2.Assign`` nodes;
    * calls without a destination and branches become ``ast2.Statement``;
    * registers become variable names via ``self.get_variable``;
    * constants become numbers, ObjC string literals (when the address is in
      ``self.func.binary.cfstrings``), selectors, or names of globals. Globals
      are recorded in ``self.globals`` as ``long`` declarations.

    Unsupported inputs are emitted as ``ast2.Todo(str(o))``; for
    instructions it is wrapped in a statement. Operands are converted
    recursively, and the conversion order is kept deliberate: converting a
    register may add it to ``self.locals``.
    """
    c = self.convert1
    if isinstance(o, UCodeAdd):
        return ast2.Assign(targets=[c(o.destination())],
                           value=ast2.BinOp(c(o.source1()), ast2.Add(), c(o.source2())),
                           decltype=None)
    elif isinstance(o, UCodeRet):
        if len(o.operands) == 0:
            return ast2.Return(None)
        assert len(o.operands) == 1
        return ast2.Return(c(o.operands[0]))
    elif isinstance(o, UCodeCall):
        call = ast2.Call(c(o.callee()), [c(p) for p in o.params()])
        if o.has_destination:
            return ast2.Assign(targets=[c(o.destination())], value=call, decltype=None)
        return ast2.Statement(expr=call)
    elif isinstance(o, UCodeMov):
        return ast2.Assign(targets=[c(o.destination())], value=c(o.source()), decltype=None)
    elif isinstance(o, UCodeRegister):
        return self.get_variable(o)
    elif isinstance(o, UCodeStore):
        deref = ast2.Dereference(c(o.pointer()))
        return ast2.Assign(targets=[deref], value=c(o.source()), decltype=None)
    elif isinstance(o, UCodeLoad):
        deref = ast2.Dereference(c(o.pointer()))
        return ast2.Assign(targets=[c(o.destination())], value=deref, decltype=None)
    elif isinstance(o, UCodeAddressOfLocal):
        return ast2.Assign(targets=[c(o.destination())],
                           value=ast2.AddressOf(variable=c(o.source())),
                           decltype=None)
    elif isinstance(o, UCodeSetMember):
        # Struct members are named by byte offset; other rewriters parse "off_<hex>".
        field = "off_%x" % o.offset()
        target = ast2.FieldAccess(object=c(o.destination()), field=field)
        return ast2.Assign(targets=[target], value=c(o.value()), decltype=None)
    elif isinstance(o, UCodeGetMember):
        field = "off_%x" % o.offset()
        value = ast2.FieldAccess(object=c(o.value()), field=field)
        return ast2.Assign(targets=[c(o.destination())], value=value, decltype=None)
    elif isinstance(o, UCodeBranch):
        return ast2.Statement(c(o.condition()))
    elif isinstance(o, UCodeSetFlag):
        if o.type == UCodeSetFlag.TYPE_ZERO and o.operation == UCodeSetFlag.OPERATION_SUB:
            comparison = ast2.Equals(left=c(o.source1()), right=c(o.source2()))
        elif o.type == UCodeSetFlag.TYPE_ZERO and o.operation == UCodeSetFlag.OPERATION_AND:
            # NOTE(review): a zero flag after AND is set when (a & b) == 0, but
            # this yields (a & b), whose truthiness is the opposite. Unlike the
            # SUB case above, this does not read as "flag is set" — confirm how
            # consumers interpret it.
            comparison = ast2.BinOp(left=c(o.source1()), op=ast2.And(), right=c(o.source2()))
        elif o.type == UCodeSetFlag.TYPE_CARRY and o.operation == UCodeSetFlag.OPERATION_SUB:
            # NOTE(review): with usual borrow semantics, carry after a - b means
            # a < b (unsigned); Gt looks inverted unless operands are swapped
            # upstream — confirm.
            comparison = ast2.BinOp(left=c(o.source1()), op=ast2.Gt(), right=c(o.source2()))
        else:
            comparison = ast2.Todo(str(o))
        target = c(o.destination())
        return ast2.Assign(targets=[target], value=comparison, decltype=None)
    elif isinstance(o, UCodeNeg):
        target = c(o.destination())
        value = ast2.Negation(value=c(o.source1()))
        return ast2.Assign(targets=[target], value=value, decltype=None)
    elif isinstance(o, (UCodeTruncate, UCodeExtend)):
        # Width changes are not modelled; treat them as plain copies.
        target = c(o.destination())
        value = c(o.source())
        return ast2.Assign(targets=[target], value=value, decltype=None)
    elif isinstance(o, UCodeArithmeticOperation):
        mnem_to_op = {"uADD": ast2.Add, "uSUB": ast2.Sub, "uMUL": ast2.Mult,
                      "uDIV": ast2.Div, "uMOD": ast2.Mod, "uAND": ast2.And,
                      "uOR": ast2.Or, "uXOR": ast2.Xor, "uSHL": ast2.Shl,
                      "uSHR": ast2.Shr}
        op_class = mnem_to_op.get(o.mnem())
        if op_class is None:
            # Unknown arithmetic mnemonic: emit a TODO, as for other unsupported
            # instructions, instead of crashing with KeyError.
            return ast2.Statement(ast2.Todo(str(o)))  # TODO
        target = c(o.destination())
        value = ast2.BinOp(left=c(o.source1()), op=op_class(), right=c(o.source2()))
        return ast2.Assign(targets=[target], value=value, decltype=None)
    elif isinstance(o, UCodeInstruction):
        return ast2.Statement(ast2.Todo(str(o)))  # TODO
    elif isinstance(o, UCodeConstant):
        if not o.name:
            return ast2.Num(n=o.value)
        if o.name.startswith("_OBJC_CFSTRING_$_"):
            if o.value in self.func.binary.cfstrings:
                return ast2.ObjCString(value=self.func.binary.cfstrings[o.value].string)
        if o.name.startswith("_OBJC_SELECTOR_$_"):
            return ast2.ObjCSelector(value=o.name.replace("_OBJC_SELECTOR_$_", ""))
        # Any other named constant is referenced as a global; declare it once.
        if o.name not in self.globals:
            self.globals[o.name] = ast2.Declaration("long", o.name)
        return ast2.Name(o.name)
    else:
        return ast2.Todo(str(o))  # TODO