def ConvertArrayIndexToPointerDereference(ast, meta_info):
    """Replace array subscripting with explicit address arithmetic.

    The AST is rewritten in place in three phases:

    Phase 1:  a[1][2] = b;        ->  *(a + 1 * 10 + 2) = b;
    Phase 2:  fun (int a[5][10])  ->  fun (int a[][10])
    Phase 3:  int a[5][10];       ->  int a[50];

    Args:
        ast: root of the tree to rewrite.
        meta_info: provides `type_links`, which is read for existing nodes
            and extended with entries for the nodes created here.
    """

    # Phase 1
    def StartsSubscriptChain(node, parent):
        if not isinstance(node, c_ast.ArrayRef):
            return False
        base_type = meta_info.type_links[node.name]
        # When the subscripted expression is not an array, e.g.
        #   int **b = a;
        #   printf("%d\n", b[1][1]);
        # every ArrayRef is treated as the head of its own chain.
        return not (isinstance(base_type, c_ast.ArrayDecl) and
                    isinstance(parent, c_ast.ArrayRef))

    for head, owner in common.FindMatchingNodesPostOrder(
            ast, ast, StartsSubscriptChain):
        base, offset = MakeCombinedSubscript(head, meta_info)
        address = base if offset is None else c_ast.BinaryOp("+", base, offset)
        result_type = meta_info.type_links[head]
        # TODO: low confidence - double check this
        meta_info.type_links[address] = meta_info.type_links[base]
        if isinstance(result_type, c_ast.ArrayDecl):
            # The subscripts index the array only partially, so the value of
            # the expression is just the computed address.
            replacement = address
        else:
            replacement = c_ast.UnaryOp("*", address)
            # The dereference stands for the same expression as the chain
            # head, so it inherits the head's type.
            meta_info.type_links[replacement] = meta_info.type_links[head]
        common.ReplaceNode(owner, head, replacement)

    # Phase 2
    def IsArrayTypedParam(node, parent):
        return (isinstance(parent, c_ast.ParamList) and
                not isinstance(node, c_ast.EllipsisParam) and
                isinstance(node.type, c_ast.ArrayDecl))

    for param, _ in common.FindMatchingNodesPreOrder(
            ast, ast, IsArrayTypedParam):
        # Drop the outermost dimension of the parameter's array type.
        param.type.dim = None

    # Phase 3
    def StartsArrayDeclChain(node, parent):
        return (isinstance(node, c_ast.ArrayDecl) and
                not isinstance(parent, c_ast.ArrayDecl))

    for decl_head, _ in common.FindMatchingNodesPreOrder(
            ast, ast, StartsArrayDeclChain):
        CollapseArrayDeclChain(decl_head)
def _DoPrintfSplitter(call: c_ast.FuncCall, parent, use_specialized_printf):
    """Break one printf-style call into one print call per format piece.

    The format string (first argument, surrounding quotes stripped) is
    tokenized. With `use_specialized_printf` set and a single piece, the call
    is retargeted in place. Otherwise the call is replaced, inside its
    statement list, by one call per piece built with MakePrintfCall.

    Args:
        call: the call node to rewrite.
        parent: the node that currently holds `call`.
        use_specialized_printf: enables the single-piece in-place rewrite and
            is forwarded to MakePrintfCall.
    """
    exprs: List = call.args.exprs
    pieces = TokenizeFormatString(exprs[0].value[1:-1])
    assert len(pieces) >= 1

    if use_specialized_printf and len(pieces) == 1:
        only = pieces[0]
        is_plain_text = len(only) <= 1 or only[0] != "%"
        if is_plain_text:
            # Keep the string argument and put the stream in front of it.
            call.name.name = "write_s"
            exprs.insert(0, CONST_STDOUT)
        else:
            # The format string is no longer needed; the stream takes its slot.
            call.name.name = GetSingleArgPrintForFormat(only)
            exprs[0] = CONST_STDOUT
        return

    siblings = common.GetStatementList(parent)
    if not siblings:
        # No statement list to splice into: wrap the call in a fresh Compound.
        siblings = [call]
        common.ReplaceNode(parent, call, c_ast.Compound(siblings))

    # Work on a copy that excludes the format string, so consuming values
    # below does not touch the call's own argument list.
    pending = exprs[1:]
    # NOTE: known small bug - all arguments should be evaluated first and
    # printed afterwards, instead of interleaving computation and printing.
    replacements = []
    for piece in pieces:
        takes_value = piece[0] == '%' and len(piece) > 1
        value = pending.pop(0) if takes_value else None
        replacements.append(MakePrintfCall(piece, value, use_specialized_printf))

    where = siblings.index(call)
    siblings[where:where + 1] = replacements
def SimpleConstantFolding(node, parent):
    """Fold a unary minus applied to a constant into the constant.

    When `node` is `-<constant>`, a "-" is prepended to the constant's textual
    value and the constant takes the place of `node` under `parent`. The
    children of `node` are then visited recursively.

    Args:
        node: subtree to process.
        parent: the node that currently holds `node`.
    """
    negates_constant = (
        isinstance(node, c_ast.UnaryOp)
        and node.op == "-"
        and isinstance(node.expr, c_ast.Constant)
    )
    if negates_constant:
        literal = node.expr
        literal.value = "-" + literal.value
        common.ReplaceNode(parent, node, literal)

    for child in node:
        SimpleConstantFolding(child, node)
def SimplifyAddressExpressions(ast, meta_info: meta.MetaInfo):
    """Cancel adjacent `&*` and `*&` operator pairs in the tree.

    Each `&*x` and each `*&x` is replaced by `x`, in place.

    Args:
        ast: root of the tree to rewrite.
        meta_info: accepted for interface uniformity; not used here.
    """

    def CancelUnaryPair(outer_op, inner_op):
        """Replace every `<outer_op><inner_op>x` by `x`."""

        def IsPair(node, _):
            return (isinstance(node, c_ast.UnaryOp) and
                    isinstance(node.expr, c_ast.UnaryOp) and
                    node.op == outer_op and
                    node.expr.op == inner_op)

        for node, parent in common.FindMatchingNodesPostOrder(ast, ast, IsPair):
            common.ReplaceNode(parent, node, node.expr.expr)

    # The two patterns are handled in separate passes so that overlapping
    # chains such as "&*&*c" are simplified correctly.
    CancelUnaryPair("&", "*")
    CancelUnaryPair("*", "&")
def ConvertPreIncDecToCompoundAssignment(ast, meta_info):
    """Rewrite pre-increment/decrement as compound assignment with one.

    `++x` becomes `x += 1`; any other operator listed in
    common.PRE_INC_DEC_OPS becomes `x -= 1`.

    Args:
        ast: root of the tree to rewrite.
        meta_info: its `type_links` gains an entry for CONST_ONE and for each
            new assignment (which inherits the type of its operand).
    """

    def IsPreIncDec(node, _parent):
        if not isinstance(node, c_ast.UnaryOp):
            return False
        return node.op in common.PRE_INC_DEC_OPS

    matches = common.FindMatchingNodesPostOrder(ast, ast, IsPreIncDec)
    meta_info.type_links[CONST_ONE] = meta.INT_IDENTIFIER_TYPE
    for unary, owner in matches:
        if unary.op == "++":
            compound_op = "+="
        else:
            compound_op = "-="
        assignment = c_ast.Assignment(compound_op, unary.expr, CONST_ONE)
        meta_info.type_links[assignment] = meta_info.type_links[unary.expr]
        common.ReplaceNode(owner, unary, assignment)
def ConvertConvertAddressTakenScalarsToArray(ast, meta_info: meta.MetaInfo):
    """Turn address-taken scalar variables into one-element arrays.

    File-level and `static` scalar declarations are converted as well. Every
    use of a converted variable is wrapped in a dereference. After this
    transform all remaining scalars can be kept in registers.

    Args:
        ast: root of the tree to rewrite.
        meta_info: `sym_links` maps identifiers to their declarations;
            `type_links` is read and updated for the rewritten nodes.
    """

    def NeedsArrayWrapping(node, parent):
        if isinstance(node, c_ast.Decl) and IsScalarType(node.type):
            return isinstance(parent, c_ast.FileAST) or "static" in node.storage

        takes_address_of_id = (isinstance(node, c_ast.UnaryOp) and
                               node.op == "&" and
                               isinstance(node.expr, c_ast.ID))
        if not takes_address_of_id:
            return False
        operand_type = meta_info.type_links[node.expr]
        return IsScalarType(operand_type)

    matches = common.FindMatchingNodesPreOrder(ast, ast, NeedsArrayWrapping)
    # An `&id` match contributes the declaration of `id`; a declaration match
    # contributes itself.
    wrapped_decls = {
        meta_info.sym_links[node.expr] if isinstance(node, c_ast.UnaryOp) else node
        for node, _ in matches
    }

    # A single constant node is shared as the dimension of all new arrays.
    array_len = c_ast.Constant("int", "1")
    meta_info.type_links[array_len] = meta.INT_IDENTIFIER_TYPE

    for decl in wrapped_decls:
        assert isinstance(decl, c_ast.Decl)
        decl.type = c_ast.ArrayDecl(decl.type, array_len, [])
        if decl.init:
            decl.init = c_ast.InitList([decl.init])

    def RefersToWrappedDecl(node, _):
        return isinstance(node, c_ast.ID) and meta_info.sym_links[node] in wrapped_decls

    uses = common.FindMatchingNodesPreOrder(ast, ast, RefersToWrappedDecl)
    for ident, owner in uses:
        scalar_type = meta_info.type_links[ident]
        # The identifier now has the type derived from its updated
        # declaration; the dereference carries the original scalar type.
        meta_info.type_links[ident] = meta.GetTypeForDecl(
            meta_info.sym_links[ident].type)
        deref = c_ast.UnaryOp("*", ident)
        meta_info.type_links[deref] = scalar_type
        common.ReplaceNode(owner, ident, deref)
def ConvertWhileLoop(ast, id_gen: common.UniqueId):
    """Lower `while` and `do ... while` loops to labels and gotos.

    Each loop is replaced in place by a Compound laid out as:

            goto <L>_cond          (only for `while`, not `do ... while`)
        <L>:
            body
        <L>_cond:
            if (cond) goto <L>
        <L>_exit:

    where <L> is a fresh name obtained from `id_gen.next("while")`.

    Args:
        ast: root of the tree to rewrite.
        id_gen: source of unique label prefixes.
    """

    def IsWhileLoop(node, _):
        return isinstance(node, (c_ast.DoWhile, c_ast.While))

    candidates = common.FindMatchingNodesPostOrder(ast, ast, IsWhileLoop)
    for node, parent in candidates:
        loop_label = id_gen.next("while")
        test_label = loop_label + "_cond"
        exit_label = loop_label + "_exit"

        # Rewrite break/continue BEFORE reading node.stmt. The helper receives
        # `node` as the parent of the body, so when the body itself is a bare
        # `break;` / `continue;` it can swap out node.stmt. Capturing
        # node.stmt first would leave the stale statement in the new block.
        # NOTE(review): assumes the helper only touches the body subtree.
        common.ReplaceBreakAndContinue(node.stmt, node, test_label, exit_label)

        conditional = c_ast.If(node.cond, c_ast.Goto(loop_label), None)
        block = [c_ast.Label(loop_label, c_ast.EmptyStatement()),
                 node.stmt,
                 c_ast.Label(test_label, c_ast.EmptyStatement()),
                 conditional,
                 c_ast.Label(exit_label, c_ast.EmptyStatement())]
        if isinstance(node, c_ast.While):
            # A `while` tests its condition before the first iteration.
            block.insert(0, c_ast.Goto(test_label))
        common.ReplaceNode(parent, node, c_ast.Compound(block))
def ConvertForLoop(ast, id_gen: common.UniqueId):
    """Lower `for` loops to labels and gotos.

    Each loop is replaced in place by a Compound laid out as:

            init statements
            goto <L>_cond
        <L>:
            body
        <L>_next:
            next            (an empty statement when the loop has none)
        <L>_cond:
            if (cond) goto <L>     (a plain `goto <L>` when there is no cond)
        <L>_exit:

    where <L> is a fresh name obtained from `id_gen.next("for")`.

    Args:
        ast: root of the tree to rewrite.
        id_gen: source of unique label prefixes.
    """
    candidates = common.FindMatchingNodesPostOrder(
        ast, ast, lambda n, _: isinstance(n, c_ast.For))
    for node, parent in candidates:
        loop_label = id_gen.next("for")
        next_label = loop_label + "_next"
        test_label = loop_label + "_cond"
        exit_label = loop_label + "_exit"

        # Rewrite break/continue BEFORE reading node.stmt. The helper receives
        # `node` as the parent of the body, so when the body itself is a bare
        # `break;` / `continue;` it can swap out node.stmt. Capturing
        # node.stmt first would leave the stale statement in the new block.
        # NOTE(review): assumes the helper only touches the body subtree.
        common.ReplaceBreakAndContinue(node.stmt, node, next_label, exit_label)

        goto = c_ast.Goto(loop_label)
        conditional = c_ast.If(node.cond, goto, None) if node.cond else goto
        block = ExtractForInitStatements(node.init) + [
            c_ast.Goto(test_label),
            c_ast.Label(loop_label, c_ast.EmptyStatement()),
            node.stmt,
            c_ast.Label(next_label, c_ast.EmptyStatement()),
            node.next if node.next else c_ast.EmptyStatement(),
            c_ast.Label(test_label, c_ast.EmptyStatement()),
            conditional,
            c_ast.Label(exit_label, c_ast.EmptyStatement()),
        ]
        common.ReplaceNode(parent, node, c_ast.Compound(block))
def ConvertToGotos(if_stmt: c_ast.If, parent, id_gen: common.UniqueId):
    """Rewrite an `if` so that both of its branches are plain gotos.

    The original branch bodies are moved behind the `if`, each under its own
    label, and the statement list holding the `if` is updated in place:

            [leading expressions of a comma-expression condition]
            if (cond) goto <L>_true; else goto <L>_false;
        <L>_true:
            true branch
            goto <L>_end           (only when a false branch follows)
        <L>_false:
            false branch
        <L>_end:

    A branch that is empty or already a goto gets no label or body; an
    existing goto is kept as is, and an empty branch jumps to <L>_end. An
    `if` that already has two goto branches and a condition that is not an
    expression list is left untouched.

    Args:
        if_stmt: the `if` node to rewrite.
        parent: the node that currently holds `if_stmt`.
        id_gen: source of unique label prefixes.
    """
    true_is_goto = isinstance(if_stmt.iftrue, c_ast.Goto)
    false_is_goto = isinstance(if_stmt.iffalse, c_ast.Goto)
    if (true_is_goto and false_is_goto and
            not isinstance(if_stmt.cond, c_ast.ExprList)):
        return

    prefix = id_gen.next("if")
    true_label = prefix + "_true"
    false_label = prefix + "_false"
    end_label = prefix + "_end"
    skip_true = common.IsEmpty(if_stmt.iftrue) or true_is_goto
    skip_false = common.IsEmpty(if_stmt.iffalse) or false_is_goto

    emitted: List[c_ast.Node] = []
    # TODO: this should be done in EliminateExpressionLists(
    if isinstance(if_stmt.cond, c_ast.ExprList):
        # Only the last expression is the actual condition; the ones before
        # it become standalone statements in front of the `if`.
        leading = if_stmt.cond.exprs
        if_stmt.cond = leading.pop()
        emitted.extend(leading)
    emitted.append(if_stmt)

    # The branch bodies are captured here, before the branches are
    # overwritten with gotos further down.
    if not skip_true:
        emitted.append(c_ast.Label(true_label, c_ast.EmptyStatement()))
        emitted.append(if_stmt.iftrue)
        if not skip_false:
            # Jump over the false branch that is emitted next.
            emitted.append(c_ast.Goto(end_label))
    if not skip_false:
        emitted.append(c_ast.Label(false_label, c_ast.EmptyStatement()))
        emitted.append(if_stmt.iffalse)
    emitted.append(c_ast.Label(end_label, c_ast.EmptyStatement()))

    if not true_is_goto:
        if_stmt.iftrue = c_ast.Goto(end_label if skip_true else true_label)
    if not false_is_goto:
        if_stmt.iffalse = c_ast.Goto(end_label if skip_false else false_label)

    siblings = common.GetStatementList(parent)
    if not siblings:
        # No statement list to splice into: wrap the `if` in a fresh Compound.
        siblings = [if_stmt]
        common.ReplaceNode(parent, if_stmt, c_ast.Compound(siblings))

    at = siblings.index(if_stmt)
    siblings[at:at + 1] = emitted