def FixNodeRequiringBoolInt(ast: c_ast.Node, meta_info):
    """Make every expression used in a boolean context an explicit int test.

    Visits the nodes selected by `IsNodeRequiringBoolInt` in post-order and,
    wherever an operand is not accepted by `IsExprOfTypeBoolInt`, rewrites:

    * the condition `c` of an `if` / `for` into `c != 0`
    * `!e` into `e == 0` (the `!` node itself is replaced in its parent)
    * each operand of a short-circuit operator into `operand != 0`

    Every comparison created here, as well as the shared `CONST_ZERO` node,
    is registered in `meta_info.type_links` as `meta.INT_IDENTIFIER_TYPE`.

    Args:
        ast: root of the tree to rewrite; modified in place.
        meta_info: object whose `type_links` mapping is updated.
    """

    def MakeZeroComparison(op, expr):
        # Build `expr <op> 0` and record its type.
        comparison = c_ast.BinaryOp(op, expr, CONST_ZERO)
        meta_info.type_links[comparison] = meta.INT_IDENTIFIER_TYPE
        return comparison

    candidates = common.FindMatchingNodesPostOrder(
        ast, ast, IsNodeRequiringBoolInt)
    meta_info.type_links[CONST_ZERO] = meta.INT_IDENTIFIER_TYPE
    for node, parent in candidates:
        if isinstance(node, c_ast.If):
            if not IsExprOfTypeBoolInt(node.cond):
                node.cond = MakeZeroComparison("!=", node.cond)
        elif isinstance(node, c_ast.For) and node.cond:
            if not IsExprOfTypeBoolInt(node.cond):
                node.cond = MakeZeroComparison("!=", node.cond)
        elif isinstance(node, c_ast.UnaryOp) and node.op == "!":
            if not IsExprOfTypeBoolInt(node.expr):
                # note: we are replacing the "!" node.
                # Rebinding the loop variable alone would leave the tree
                # untouched, so the new node is spliced into the parent.
                common.ReplaceNode(
                    parent, node, MakeZeroComparison("==", node.expr))
        elif (isinstance(node, c_ast.BinaryOp) and
              node.op in common.SHORT_CIRCUIT_OPS):
            if not IsExprOfTypeBoolInt(node.left):
                node.left = MakeZeroComparison("!=", node.left)
            if not IsExprOfTypeBoolInt(node.right):
                node.right = MakeZeroComparison("!=", node.right)
def LiftStaticAndExternToGlobalScope(ast: c_ast.FileAST,
                                     meta_info: meta.MetaInfo,
                                     id_gen: common.UniqueId):
    """Strip `static` / `extern` and hoist such declarations to file scope.

    Requires that if statements only have gotos.

    Why the constraint?
    We want all variables we need to allocate memory for to be at the
    top level.

    For every declaration selected by `IsExternOrStatic`:

    * `static`: the storage class is removed and the symbol is renamed to
      `<id_gen.next("__static")>_<name>` via `RenameSymbol`. Variable-like
      declarations (TypeDecl / ArrayDecl / PtrDecl) that are not already at
      file scope are moved to the front of `ast.ext`; function declarations
      stay where they are.
    * `extern`: the storage class is removed and the declaration is moved
      to the front of `ast.ext` unless it is already at file scope.

    Args:
        ast: the translation unit; modified in place.
        meta_info: passed through to `RenameSymbol`.
        id_gen: source of unique name prefixes.

    Raises:
        AssertionError: for a static declaration of an unexpected kind, or
            when the parent of a nested declaration has no statement list.
    """

    def MoveToGlobalScope(decl, parent):
        # Declarations that already live at file scope stay in place.
        if parent is ast:
            return
        # rip it out
        stmts = common.GetStatementList(parent)
        assert stmts, parent
        stmts.remove(decl)
        # TODO: insert it just outside the function
        ast.ext.insert(0, decl)

    candidates = common.FindMatchingNodesPostOrder(ast, ast, IsExternOrStatic)
    for decl, parent in candidates:
        if "static" in decl.storage:
            new_name = id_gen.next("__static") + "_" + decl.name
            decl.storage.remove("static")
            RenameSymbol(decl, new_name, meta_info)
            # The kind check is deliberately separate from the scope check:
            # a file-scope static variable is a valid input and must not
            # fall through to the "unexpected declaration" assertion.
            if isinstance(decl.type,
                          (c_ast.TypeDecl, c_ast.ArrayDecl, c_ast.PtrDecl)):
                MoveToGlobalScope(decl, parent)
            elif isinstance(decl.type, c_ast.FuncDecl):
                pass
            else:
                assert False, decl
        if "extern" in decl.storage:
            decl.storage.remove("extern")
            MoveToGlobalScope(decl, parent)
def ConvertArrayIndexToPointerDereference(ast, meta_info):
    """Eliminates array subscripts and multi-dimensional arrays.

    Phase 1: a[1][2] = b; -> *(a + 1 * 10 + 2) = b;
    Phase 2: fun (int a[5][10]) -> fun (int a[][10])
    Phase 3: int a[5][10]; -> int a[50];

    The tree rooted at `ast` is modified in place. `meta_info.type_links`
    is read to obtain expression types and is extended with entries for the
    address and dereference nodes created in phase 1.
    """

    def IsArrayRefChainHead(node, parent):
        # Selects the ArrayRef nodes at which a subscript chain is rewritten.
        if not isinstance(node, c_ast.ArrayRef):
            return False
        name_type = meta_info.type_links[node.name]
        # A subscript applied to something that is not an array (e.g. a
        # pointer) always starts its own chain:
        # int **b = a;
        # printf("%d\n", b[1][1]);
        if not isinstance(name_type, c_ast.ArrayDecl):
            return True
        # For real arrays only the outermost ArrayRef of a chain qualifies.
        if not isinstance(parent, c_ast.ArrayRef):
            return True
        return False

    # Phase 1
    ref_chains = common.FindMatchingNodesPostOrder(ast, ast, IsArrayRefChainHead)
    for chain_head, parent in ref_chains:
        # `s` is the combined subscript expression, or None if there is none
        name, s = MakeCombinedSubscript(chain_head, meta_info)
        if s is None:
            addr = name
        else:
            addr = c_ast.BinaryOp("+", name, s)
        head_type = meta_info.type_links[chain_head]
        # TODO: low confidence - double check this
        meta_info.type_links[addr] = meta_info.type_links[name]
        if isinstance(head_type, c_ast.ArrayDecl):
            # the array ref sequence only partially indexes the array, so the result is just an address
            common.ReplaceNode(parent, chain_head, addr)
        else:
            deref = c_ast.UnaryOp("*", addr)
            meta_info.type_links[deref] = meta_info.type_links[chain_head]  # expression has not changed
            common.ReplaceNode(parent, chain_head, deref)

    # Phase 2
    def IsArrayDeclParam(node, parent):
        # Matches function parameters whose declared type is an array.
        if not isinstance(parent, c_ast.ParamList):
            return False
        if isinstance(node, c_ast.EllipsisParam):
            return False
        return isinstance(node.type, c_ast.ArrayDecl)

    decl_params = common.FindMatchingNodesPreOrder(
        ast, ast, IsArrayDeclParam)
    for param, _ in decl_params:
        t = param.type
        # drop the dimension of the parameter's top-level ArrayDecl
        t.dim = None

    # Phase 3
    def IsArrayDeclChainHead(node, parent):
        # Matches an ArrayDecl that is not nested inside another ArrayDecl.
        if not isinstance(node, c_ast.ArrayDecl):
            return False
        return not isinstance(parent, c_ast.ArrayDecl)

    decl_chains = common.FindMatchingNodesPreOrder(
        ast, ast, IsArrayDeclChainHead)
    for chain_head, parent in decl_chains:
        CollapseArrayDeclChain(chain_head)
def ConvertArrayStructRef(ast: c_ast.FileAST):
    """Rewrite every `p->field` access as `(*p).field`.

    All StructRef nodes of type "->" are modified in place: their base
    expression is wrapped in a "*" UnaryOp and their type becomes ".".
    """
    arrow_refs = common.FindMatchingNodesPostOrder(
        ast, ast,
        lambda n, _parent: isinstance(n, c_ast.StructRef) and n.type == "->")
    for ref, _parent in arrow_refs:
        pointer_expr = ref.name
        ref.name = c_ast.UnaryOp("*", pointer_expr)
        ref.type = "."
def IfTransform(ast: c_ast.Node, id_gen: common.UniqueId):
    """Hand every `if` statement to `ConvertToGotos`.

    The goal is that no condition contains an expression list and that the
    true and false branches consist of at most a goto.

    This should be run after the loop conversions.
    """

    def IsIf(node, _parent):
        return isinstance(node, c_ast.If)

    for if_stmt, parent in common.FindMatchingNodesPostOrder(ast, ast, IsIf):
        ConvertToGotos(if_stmt, parent, id_gen)
def SimplifyAddressExpressions(ast, meta_info: meta.MetaInfo):
    """Cancel directly nested `&*` and `*&` operator pairs.

    First every `&(*e)` is replaced by `e`, then, in a second search over
    the updated tree, every `*(&e)` is replaced by `e`.
    `meta_info` is accepted but not used.
    """

    def MakeMatcher(outer_op, inner_op):
        # Predicate for a UnaryOp `outer_op` applied directly to a
        # UnaryOp `inner_op`.
        def Matches(node, _parent):
            if not isinstance(node, c_ast.UnaryOp) or node.op != outer_op:
                return False
            inner = node.expr
            return isinstance(inner, c_ast.UnaryOp) and inner.op == inner_op

        return Matches

    # we need to split these in case of "&*&*c"
    for outer_op, inner_op in (("&", "*"), ("*", "&")):
        matcher = MakeMatcher(outer_op, inner_op)
        matches = common.FindMatchingNodesPostOrder(ast, ast, matcher)
        for node, parent in matches:
            common.ReplaceNode(parent, node, node.expr.expr)
def ConvertPreIncDecToCompoundAssignment(ast, meta_info):
    """Turn `++x` into `x += 1` and the other pre-op into `x -= 1`.

    The new Assignment node gets the type recorded for the operand, and the
    shared `CONST_ONE` node is registered as `meta.INT_IDENTIFIER_TYPE`.
    """
    matches = common.FindMatchingNodesPostOrder(
        ast, ast,
        lambda n, _parent: (isinstance(n, c_ast.UnaryOp) and
                            n.op in common.PRE_INC_DEC_OPS))
    meta_info.type_links[CONST_ONE] = meta.INT_IDENTIFIER_TYPE
    for inc_dec, parent in matches:
        operand = inc_dec.expr
        if inc_dec.op == "++":
            assign_op = "+="
        else:
            assign_op = "-="
        replacement = c_ast.Assignment(assign_op, operand, CONST_ONE)
        meta_info.type_links[replacement] = meta_info.type_links[operand]
        common.ReplaceNode(parent, inc_dec, replacement)
def ForwardGotosAndRemoveUnusedLabels(node: c_ast.Node,
                                      forwards: Mapping[str, str]):
    """Retarget gotos through `forwards` and drop the forwarded labels.

    Each goto whose target is a key of `forwards` is redirected repeatedly
    until its target is no longer a key. Each label whose name is a key of
    `forwards` is removed from its parent's statement list.

    Raises:
        AssertionError: when the parent of a label to be removed has no
            statement list.
    """
    matches = common.FindMatchingNodesPostOrder(
        node, node,
        lambda n, _parent: isinstance(n, (c_ast.Goto, c_ast.Label)))
    for stmt, parent in matches:
        if isinstance(stmt, c_ast.Goto):
            # follow the forwarding chain to its end
            target = stmt.name
            while target in forwards:
                target = forwards[target]
            stmt.name = target
            continue
        if isinstance(stmt, c_ast.Label) and stmt.name in forwards:
            siblings = common.GetStatementList(parent)
            assert siblings, parent
            siblings.remove(stmt)
def PrintfSplitterTransform(ast: c_ast.FileAST, use_specialized_printf):
    """Apply `_DoPrintfSplitter` to every call selected by `_IsSuitablePrintf`.

    When `use_specialized_printf` is truthy and at least one call was
    processed, every top-level declaration named "printf" is removed from
    `ast.ext` afterwards.
    """
    calls = common.FindMatchingNodesPostOrder(ast, ast, _IsSuitablePrintf)
    for call, parent in calls:
        _DoPrintfSplitter(call, parent, use_specialized_printf)

    if use_specialized_printf and len(calls) != 0:
        # remove old printf prototype (use ellipsis)
        top_level = ast.ext
        stale = [d for d in top_level
                 if isinstance(d, c_ast.Decl) and d.name == "printf"]
        for decl in stale:
            top_level.remove(decl)
def EliminateExpressionLists(ast):
    """Splice expression lists into the statement list that contains them.

    This works best after the for loop conversions.

    Expression lists whose parent is a function call are left alone, and so
    are those whose parent yields no statement list.

    TODO: make this work for all cases
    currently we duplicate some work in the "if transform"
    """

    def IsExpressionList(node: c_ast.Node, parent):
        if isinstance(parent, c_ast.FuncCall):
            return False
        return isinstance(node, c_ast.ExprList)

    matches = common.FindMatchingNodesPostOrder(ast, ast, IsExpressionList)
    for expr_list, parent in matches:
        siblings = common.GetStatementList(parent)
        if not siblings:
            continue
        at = siblings.index(expr_list)
        # replace the single list node by its individual expressions
        siblings[at:at + 1] = expr_list.exprs
def ConvertCompoundAssignment(ast: c_ast.Node, meta_info: meta.MetaInfo,
                              _id_gen):
    """Expand `x op= y` into `x = x op y` for plain identifier targets.

    This works best after ConvertArrayIndexToPointerDereference.

    The new BinaryOp gets the type recorded for the assignment. Compound
    assignments whose target is not an ID are left unchanged.
    """

    def IsAssignment(node, _parent):
        return isinstance(node, c_ast.Assignment)

    for assign, _parent in common.FindMatchingNodesPostOrder(
            ast, ast, IsAssignment):
        if assign.op == "=":
            continue
        target = assign.lvalue
        if not isinstance(target, c_ast.ID):
            # TODO: handle targets of the form `*expr`
            continue
        # "+=" -> "+", "<<=" -> "<<", ...
        binary_op = assign.op[:-1]
        expanded = c_ast.BinaryOp(binary_op, target, assign.rvalue)
        meta_info.type_links[expanded] = meta_info.type_links[assign]
        assign.rvalue = expanded
        assign.op = "="
def ConvertWhileLoop(ast, id_gen: common.UniqueId):
    """Lower `while` and `do ... while` loops to labels and gotos.

    Each loop is replaced by a Compound of the form

        [goto <L>_cond;]      (only for `while`)
        <L>:
        <body>
        <L>_cond:
        if (<cond>) goto <L>;
        <L>_exit:

    where <L> comes from `id_gen.next("while")`. The body's break/continue
    statements are rewritten by `common.ReplaceBreakAndContinue`, which is
    given the `_cond` and `_exit` labels.
    """

    def MakeLabel(name):
        return c_ast.Label(name, c_ast.EmptyStatement())

    loops = common.FindMatchingNodesPostOrder(
        ast, ast,
        lambda n, _parent: isinstance(n, (c_ast.DoWhile, c_ast.While)))
    for loop, parent in loops:
        loop_label = id_gen.next("while")
        test_label = loop_label + "_cond"
        exit_label = loop_label + "_exit"

        stmts = []
        if isinstance(loop, c_ast.While):
            # a `while` tests its condition before the first iteration
            stmts.append(c_ast.Goto(test_label))
        stmts += [
            MakeLabel(loop_label),
            loop.stmt,
            MakeLabel(test_label),
            c_ast.If(loop.cond, c_ast.Goto(loop_label), None),
            MakeLabel(exit_label),
        ]
        common.ReplaceBreakAndContinue(loop.stmt, loop, test_label, exit_label)
        common.ReplaceNode(parent, loop, c_ast.Compound(stmts))
def PrintfSplitterTransform(ast: c_ast.FileAST, use_specialized_printf):
    """Apply `_DoPrintfSplitter` to every call selected by `_IsSuitablePrintf`.

    When `use_specialized_printf` is truthy and at least one call was
    processed, the top-level declarations named "puts" or "printf" are
    removed from `ast.ext`, and `PUTS` followed by the values of
    `PRINTF_PROTOTYPES` is prepended instead.
    """
    calls = common.FindMatchingNodesPostOrder(ast, ast, _IsSuitablePrintf)
    for call, parent in calls:
        _DoPrintfSplitter(call, parent, use_specialized_printf)

    if use_specialized_printf and len(calls) != 0:
        top_level = ast.ext
        # remove old prototypes
        stale = [d for d in top_level
                 if isinstance(d, c_ast.Decl) and d.name in {"puts", "printf"}]
        for decl in stale:
            top_level.remove(decl)
        # prepend protos
        replacements = [PUTS]
        replacements.extend(PRINTF_PROTOTYPES.values())
        top_level[0:0] = replacements
def ConvertForLoop(ast, id_gen: common.UniqueId):
    """Lower `for` loops to labels and gotos.

    Each loop is replaced by a Compound of the form

        <init statements>
        goto <L>_cond;
        <L>:
        <body>
        <L>_next:
        <next expression or empty statement>
        <L>_cond:
        if (<cond>) goto <L>;     (a plain `goto <L>;` without condition)
        <L>_exit:

    where <L> comes from `id_gen.next("for")`. The body's break/continue
    statements are rewritten by `common.ReplaceBreakAndContinue`, which is
    given the `_next` and `_exit` labels.
    """

    def MakeLabel(name):
        return c_ast.Label(name, c_ast.EmptyStatement())

    def IsFor(node, _parent):
        return isinstance(node, c_ast.For)

    for loop, parent in common.FindMatchingNodesPostOrder(ast, ast, IsFor):
        loop_label = id_gen.next("for")
        next_label = loop_label + "_next"
        test_label = loop_label + "_cond"
        exit_label = loop_label + "_exit"

        back_edge = c_ast.Goto(loop_label)
        if loop.cond:
            back_edge = c_ast.If(loop.cond, back_edge, None)
        step = loop.next or c_ast.EmptyStatement()

        stmts = ExtractForInitStatements(loop.init) + [
            c_ast.Goto(test_label),
            MakeLabel(loop_label),
            loop.stmt,
            MakeLabel(next_label),
            step,
            MakeLabel(test_label),
            back_edge,
            MakeLabel(exit_label),
        ]
        common.ReplaceBreakAndContinue(loop.stmt, loop, next_label, exit_label)
        common.ReplaceNode(parent, loop, c_ast.Compound(stmts))
def ConvertPostToPreIncDec(ast: c_ast.Node):
    """Turn the nodes selected by `IsSuitablePostIncDec` into pre-inc/dec.

    Only the operator string is changed, by dropping its first character
    (the leading "p" of the post-operator spelling).
    """
    matches = common.FindMatchingNodesPostOrder(ast, ast, IsSuitablePostIncDec)
    for inc_dec, _parent in matches:
        post_op = inc_dec.op
        inc_dec.op = post_op[1:]
def ShortCircuitIfTransform(ast: c_ast.Node, id_gen: common.UniqueId):
    """Hand every `if` statement to `ConvertShortCircuitIf`.

    Requires that if statements only have gotos.
    """
    if_stmts = common.FindMatchingNodesPostOrder(
        ast, ast, lambda n, _parent: isinstance(n, c_ast.If))
    for stmt, parent in if_stmts:
        ConvertShortCircuitIf(stmt, parent, id_gen)