def test_ClassDef(self):
    """Exercise memory lookups of a symbexec engine.

    A symbolic state is built from memory cells of various sizes plus two
    register-like identifiers, then find_mem_by_addr, eval_ExprMem (full,
    partial and straddling reads), func_read, modified() and missing-key
    lookups in e.symbols are checked.
    """
    from miasm2.expression.expression import ExprInt32, ExprId, ExprMem, ExprCompose
    from miasm2.arch.x86.sem import ir_x86_32
    from miasm2.ir.symbexec import symbexec

    # 32-bit constant addresses used throughout the test
    addrX = ExprInt32(-1)
    addr0 = ExprInt32(0)
    addr1 = ExprInt32(1)
    addr8 = ExprInt32(8)
    addr9 = ExprInt32(9)
    addr20 = ExprInt32(20)
    addr40 = ExprInt32(40)
    addr50 = ExprInt32(50)
    # Memory cells: default-size accesses, plus explicit 8-bit ('v') and
    # 16-bit ('w') accesses at 40 and 50
    mem0 = ExprMem(addr0)
    # NOTE(review): mem1 has no explicit size but is mapped to the 8-bit id_y
    # below, and the 16-bit mem40w is mapped to id_x; if ExprMem/ExprId
    # default to 32 bits these sizes disagree -- confirm this is intended.
    mem1 = ExprMem(addr1)
    mem8 = ExprMem(addr8)
    mem9 = ExprMem(addr9)
    mem20 = ExprMem(addr20)
    mem40v = ExprMem(addr40, 8)
    mem40w = ExprMem(addr40, 16)
    mem50v = ExprMem(addr50, 8)
    mem50w = ExprMem(addr50, 16)
    id_x = ExprId('x')
    id_y = ExprId('y', 8)
    id_a = ExprId('a')
    id_eax = ExprId('eax_init')

    # Initial symbolic state
    e = symbexec(
        ir_x86_32(),
        {
            mem0: id_x,
            mem1: id_y,
            mem9: id_x,
            mem40w: id_x,
            mem50v: id_y,
            id_a: addr0,
            id_eax: addr0
        })

    # Lookup of a known cell, and of an address that is not in the state
    self.assertEqual(e.find_mem_by_addr(addr0), mem0)
    self.assertEqual(e.find_mem_by_addr(addrX), None)

    # The address expression is evaluated: 1 - 1 reads the cell at 0
    self.assertEqual(e.eval_ExprMem(ExprMem(addr1 - addr1)), id_x)
    self.assertEqual(e.eval_ExprMem(ExprMem(addr1, 8)), id_y)

    # Reads straddling known and unknown memory (old list-of-tuples
    # ExprCompose form: (expr, start_bit, stop_bit))
    self.assertEqual(
        e.eval_ExprMem(ExprMem(addr1 + addr1)),
        ExprCompose([(id_x[16:32], 0, 16), (ExprMem(ExprInt32(4), 16), 16, 32)]))
    self.assertEqual(
        e.eval_ExprMem(mem8),
        ExprCompose([(id_x[0:24], 0, 24), (ExprMem(ExprInt32(11), 8), 24, 32)]))

    # Partial read of a wider cell, and wider read over a narrower cell
    self.assertEqual(e.eval_ExprMem(mem40v), id_x[:8])
    self.assertEqual(
        e.eval_ExprMem(mem50w),
        ExprCompose([(id_y, 0, 8), (ExprMem(ExprInt32(51), 8), 8, 16)]))

    # Unknown memory is returned as itself, with or without func_read
    self.assertEqual(e.eval_ExprMem(mem20), mem20)
    e.func_read = lambda x: x
    self.assertEqual(e.eval_ExprMem(mem20), mem20)

    self.assertEqual(set(e.modified()), set(e.symbols))
    # Looking up a cell that was never set raises KeyError
    self.assertRaises(
        KeyError, e.symbols.__getitem__, ExprMem(ExprInt32(100)))
def simp_subwc_cf(_, expr):
    """SUBWC_CF(A, B, SUB_CF(C, D)) => SUB_CF({C, A}, {D, B})

    Fold a two-word subtract-with-carry chain into one wider subtraction
    flag. SUB_CF(C, D) comes from the low-word subtraction and is consumed
    by SUBWC_CF on the high words A, B. ExprCompose lists its arguments from
    the low bits to the high bits, so the low-word operands (C, D) must come
    first in each composition.

    @expr: Expr to simplify
    Return the folded FLAG_SUB_CF expression, or @expr unchanged when the
    pattern does not match.
    """
    if not expr.is_op('FLAG_SUBWC_CF'):
        return expr
    op3 = expr.args[2]
    if not op3.is_op("FLAG_SUB_CF"):
        return expr

    # Low word (C, D) in the low bits, high word (A, B) above it
    op1 = ExprCompose(op3.args[0], expr.args[0])
    op2 = ExprCompose(op3.args[1], expr.args[1])
    return ExprOp("FLAG_SUB_CF", op1, op2)
def simp_sign_subwc_cf(expr_s, expr):
    """SIGN_SUBWC(A, B, SUB_CF(C, D)) => SIGN_SUB({C, A}, {D, B})

    Fold a two-word subtract-with-carry sign flag into the sign flag of one
    wider subtraction. SUB_CF(C, D) comes from the low-word subtraction and
    is consumed by the high-word operation on A, B. ExprCompose lists its
    arguments from the low bits to the high bits, so the low-word operands
    (C, D) must come first in each composition.

    @expr: Expr to simplify
    Return the folded FLAG_SIGN_SUB expression, or @expr unchanged when the
    pattern does not match.
    """
    if not expr.is_op('FLAG_SIGN_SUBWC'):
        return expr
    op3 = expr.args[2]
    if not op3.is_op("FLAG_SUB_CF"):
        return expr

    # Low word (C, D) in the low bits, high word (A, B) above it
    op1 = ExprCompose(op3.args[0], expr.args[0])
    op2 = ExprCompose(op3.args[1], expr.args[1])
    return ExprOp("FLAG_SIGN_SUB", op1, op2)
def simp_compose(e_s, expr):
    """Common simplifications on ExprCompose.

    @e_s: simplification callable applied to intermediate expressions
    @expr: ExprCompose to simplify
    Return an Expr equivalent to @expr; when no rule fires, the (merged and
    flattened) arguments are rebuilt into a new ExprCompose.
    """
    # Flatten nested compositions: {{a, b}, c} => {a, b, c}
    parts = []
    for part in merge_sliceto_slice(expr):
        if part.is_compose():
            parts.extend(part.args)
        else:
            parts.append(part)

    # {a} with a.size == compose.size => a
    if len(parts) == 1 and parts[0].size == expr.size:
        return parts[0]

    # {X[z:], 0} spanning X.size bits => X >> z
    if len(parts) == 2 and parts[1].is_int(0):
        high_slice = parts[0]
        if (high_slice.is_slice() and
                high_slice.stop == high_slice.arg.size and
                high_slice.size + parts[1].size == high_slice.arg.size):
            src = high_slice.arg
            return src >> ExprInt(high_slice.start, src.size)

    # {@X[ptr], @Y[ptr + X/8]} => @(X+Y)[ptr]; only the first adjacent pair
    # found is merged
    for pos, (cur, nxt) in enumerate(zip(parts, parts[1:])):
        if not (cur.is_mem() and nxt.is_mem()):
            continue
        gap = e_s(nxt.arg - cur.arg)
        if gap.is_int() and cur.size % 8 == 0 and int(gap) == cur.size // 8:
            wide = ExprMem(cur.arg, cur.size + nxt.size)
            return ExprCompose(*(parts[:pos] + [wide] + parts[pos + 2:]))

    # {a, x?b:d, x?c:e, f} => x?{a, b, c, f}:{a, d, e, f}
    cond_set = set(part.cond for part in expr.args if part.is_cond())
    if len(cond_set) == 1:
        cond = next(iter(cond_set))
        if_true = [part.src1 if part.is_cond() else part for part in expr.args]
        if_false = [part.src2 if part.is_cond() else part for part in expr.args]
        return ExprCond(cond,
                        e_s(ExprCompose(*if_true)),
                        e_s(ExprCompose(*if_false)))

    return ExprCompose(*parts)
def _func_read(self, expr_mem):
    """Read memory, symbolizing the bytes that lie in self.dse_memory_range.

    Accesses through a non-constant pointer are returned unchanged. For a
    constant pointer, every byte whose address is in self.dse_memory_range
    is replaced by self.dse_memory_to_expr(addr). All other bytes are read
    one at a time through the parent class's _func_read.

    @expr_mem: ExprMem to read
    """
    if not expr_mem.ptr.is_int():
        return expr_mem
    start = int(expr_mem.ptr)

    if not self.dse_memory_range:
        # Nothing symbolized: hand the whole access to the parent class
        return super(ESETrackModif, self)._func_read(expr_mem)

    ptr_size = expr_mem.ptr.size
    byte_values = []
    for addr in range(start, start + expr_mem.size // 8):
        if addr in self.dse_memory_range:
            # Symbolized byte
            byte_values.append(self.dse_memory_to_expr(addr))
            continue
        # Concrete byte, fetched individually
        single_byte = ExprMem(ExprInt(addr, ptr_size), 8)
        byte_values.append(
            super(ESETrackModif, self)._func_read(single_byte))

    if len(byte_values) == 1:
        return byte_values[0]
    # Let the simplifier merge adjacent constants,
    # e.g. {ExprInt(1, 8), ExprInt(2, 8)}
    return self.expr_simp(ExprCompose(*byte_values))
def simp_ext(_, expr):
    """Rewrite zeroExt_* / signExt_* operators as ExprCompose.

    zeroExt(A) => {A, 0}
    signExt(A) => {A, msb(A) ? 0b11..1 : 0}

    An extension whose result is exactly as wide as its argument is a no-op
    and returns the argument itself. This holds for both variants; previously
    signExt did not check it and built zero-width parts.
    Any other expression is returned unchanged.
    """
    if expr.op.startswith('zeroExt_'):
        arg = expr.args[0]
        if expr.size == arg.size:
            return arg
        return ExprCompose(arg, ExprInt(0, expr.size - arg.size))

    if expr.op.startswith("signExt_"):
        arg = expr.args[0]
        add_size = expr.size - arg.size
        if add_size == 0:
            # Same-size extension: nothing to add (mirrors the zeroExt case)
            return arg
        new_expr = ExprCompose(
            arg,
            ExprCond(
                arg.msb(),
                ExprInt(size2mask(add_size), add_size),
                ExprInt(0, add_size)
            )
        )
        return new_expr
    return expr
def eval_exprcompose(self, expr, **kwargs):
    """[DEV]: Evaluate every argument of the ExprCompose @expr in the current
    state and return them recomposed into a new ExprCompose."""
    evaluated = [self.eval_expr_visitor(part, **kwargs) for part in expr.args]
    return ExprCompose(*evaluated)
def simp_cond(_, expr):
    """
    Common simplifications on ExprCond.
    Eval exprcond src1/src2 with satisfiable/unsatisfiable condition propagation

    The rules below form an elif chain: only the first matching rule is
    applied, and the result is returned without further simplification.
    """
    # For a 1-bit symbolic condition, the condition is known to be 1 inside
    # src1 and 0 inside src2: substitute it in both branches
    if (not expr.cond.is_int()) and expr.cond.size == 1:
        src1 = expr.src1.replace_expr({expr.cond: ExprInt(1, 1)})
        src2 = expr.src2.replace_expr({expr.cond: ExprInt(0, 1)})
        if src1 != expr.src1 or src2 != expr.src2:
            return ExprCond(expr.cond, src1, src2)

    # -A ? B:C => A ? B:C  (-A is zero exactly when A is zero)
    if expr.cond.is_op('-') and len(expr.cond.args) == 1:
        expr = ExprCond(expr.cond.args[0], expr.src1, expr.src2)
    # a?x:x => x
    elif expr.src1 == expr.src2:
        expr = expr.src1
    # int ? A:B => A or B
    elif expr.cond.is_int():
        if expr.cond.arg == 0:
            expr = expr.src2
        else:
            expr = expr.src1
    # a?(a?b:c):x => a?b:x
    elif expr.src1.is_cond() and expr.cond == expr.src1.cond:
        expr = ExprCond(expr.cond, expr.src1.src1, expr.src2)
    # a?x:(a?b:c) => a?x:c
    elif expr.src2.is_cond() and expr.cond == expr.src2.cond:
        expr = ExprCond(expr.cond, expr.src1, expr.src2.src2)
    # a|int ? b:c => b with int != 0 (the condition is never zero)
    elif (expr.cond.is_op('|') and
          expr.cond.args[1].is_int() and
          expr.cond.args[1].arg != 0):
        return expr.src1

    # (C?int1:int2)?(A:B) => resolved on C depending on int1/int2 being zero
    elif (expr.cond.is_cond() and
          expr.cond.src1.is_int() and
          expr.cond.src2.is_int()):
        # .arg.arg: the raw integer value of each ExprInt
        int1 = expr.cond.src1.arg.arg
        int2 = expr.cond.src2.arg.arg
        if int1 and int2:
            # Condition is always non-zero
            expr = expr.src1
        elif int1 == 0 and int2 == 0:
            # Condition is always zero
            expr = expr.src2
        elif int1 == 0 and int2:
            # Condition non-zero exactly when C is zero: swap branches
            expr = ExprCond(expr.cond.cond, expr.src2, expr.src1)
        elif int1 and int2 == 0:
            # Condition non-zero exactly when C is non-zero
            expr = ExprCond(expr.cond.cond, expr.src1, expr.src2)

    elif expr.cond.is_compose():
        # {0, X, 0}?(A:B) => X?(A:B)
        # Constant-zero parts never make the condition non-zero, so drop them
        # NOTE(review): if every part were ExprInt 0, args would be empty and
        # ExprCompose() would get no arguments -- presumably prevented by
        # earlier compose simplification; confirm.
        args = [arg for arg in expr.cond.args if not arg.is_int(0)]
        if len(args) == 1:
            arg = args.pop()
            return ExprCond(arg, expr.src1, expr.src2)
        elif len(args) < len(expr.cond.args):
            return ExprCond(ExprCompose(*args), expr.src1, expr.src2)
    return expr
def rev16(ir, instr, arg1, arg2):
    """Reverse the byte order inside each 16-bit halfword of @arg2 into @arg1.

    Return (assignments, extra IR blocks); no extra block is needed.
    """
    byte_count = arg2.size // 8
    # Result byte i is source byte i ^ 1: bytes (0,1), (2,3), ... are swapped
    swapped = [arg2[(i ^ 1) * 8:((i ^ 1) + 1) * 8] for i in range(byte_count)]
    return [ExprAssign(arg1, ExprCompose(*swapped))], []
def check(self):
    """Assert that EDX, as evaluated by the DSE, is the expected shift."""
    regs = self.dse.ir_arch.arch.regs
    observed = self.dse.eval_expr(regs.EDX)
    # Finding a '<<' in the result shows that execution went through the
    # corresponding generated label. The shift amount is CL zero-extended to
    # 32 bits and masked with 0x1F.
    cl_widened = ExprCompose(regs.ECX[0:8], ExprInt(0x0, 24))
    expected = ExprOp('<<', regs.EDX, cl_widened & ExprInt(0x1F, 32))
    assert observed == expected
def rev(ir, instr, arg1, arg2):
    """Reverse the byte order of @arg2 into @arg1.

    Return (assignments, extra IR blocks); no extra block is needed.
    """
    bytes_low_first = [arg2[pos:pos + 8] for pos in range(0, arg2.size, 8)]
    reversed_value = ExprCompose(*reversed(bytes_low_first))
    return [ExprAssign(arg1, reversed_value)], []
def eval_exprcompose(self, expr, **kwargs):
    """[DEV]: Evaluate an ExprCompose in the current state.

    Arguments are evaluated in order. As soon as one evaluates to the TOP
    identifier, the whole composition is TOP (exprid_top(expr)) and the
    remaining arguments are not evaluated.
    """
    evaluated = []
    for part in expr.args:
        value = self.eval_expr_visitor(part, **kwargs)
        if value.is_id(TOPSTR):
            return exprid_top(expr)
        evaluated.append(value)
    return ExprCompose(*evaluated)
def categorize(self, node, lvl=0, **kwargs):
    """Recursively apply rules to @node

    Children are categorized first (bottom-up). A fresh ExprNode* wrapper is
    then rebuilt from the categorized children, and the result of
    self.apply_rules() on it is stored in its .info attribute.

    @node: ExprNode to analyze
    @lvl: actual recursion level (also used to indent debug logs)
    Return the new ExprNode.
    Raise TypeError if node.expr is not a supported Expr type.
    """
    expr = node.expr
    log_reduce.debug("\t" * lvl + "Reduce...: %s", node.expr)
    if isinstance(expr, ExprId):
        node = ExprNodeId(expr)
    elif isinstance(expr, ExprInt):
        node = ExprNodeInt(expr)
    elif isinstance(expr, ExprLoc):
        node = ExprNodeLoc(expr)
    elif isinstance(expr, ExprMem):
        arg = self.categorize(node.arg, lvl=lvl + 1, **kwargs)
        node = ExprNodeMem(ExprMem(arg.expr, expr.size))
        node.arg = arg
    elif isinstance(expr, ExprSlice):
        arg = self.categorize(node.arg, lvl=lvl + 1, **kwargs)
        node = ExprNodeSlice(ExprSlice(arg.expr, expr.start, expr.stop))
        node.arg = arg
    elif isinstance(expr, ExprOp):
        new_args = []
        for arg in node.args:
            new_a = self.categorize(arg, lvl=lvl + 1, **kwargs)
            # Categorization must not change operand sizes
            assert new_a.expr.size == arg.expr.size
            new_args.append(new_a)
        node = ExprNodeOp(ExprOp(expr.op, *[x.expr for x in new_args]))
        node.args = new_args
    elif isinstance(expr, ExprCompose):
        new_args = []
        new_expr_args = []
        for arg in node.args:
            arg = self.categorize(arg, lvl=lvl + 1, **kwargs)
            new_args.append(arg)
            new_expr_args.append(arg.expr)
        new_expr = ExprCompose(*new_expr_args)
        node = ExprNodeCompose(new_expr)
        node.args = new_args
    elif isinstance(expr, ExprCond):
        cond = self.categorize(node.cond, lvl=lvl + 1, **kwargs)
        src1 = self.categorize(node.src1, lvl=lvl + 1, **kwargs)
        src2 = self.categorize(node.src2, lvl=lvl + 1, **kwargs)
        node = ExprNodeCond(ExprCond(cond.expr, src1.expr, src2.expr))
        node.cond, node.src1, node.src2 = cond, src1, src2
    else:
        # Format the message: passing type(expr) as a second exception
        # argument left the %r placeholder unfilled
        raise TypeError("Unknown Expr Type %r" % type(expr))

    node.info = self.apply_rules(node, lvl=lvl, **kwargs)
    log_reduce.debug("\t" * lvl + "Reduce result: %s %r",
                     node.expr, node.info)
    return node
def casp(ir, instr, arg1, arg2, arg3):
    """Compare and swap a register pair (CASP).

    The pair (arg1, next register) is compared with the memory at the
    pre-incremented address held in @arg3.
    - Match: the pair (arg2, next register) is stored there. The compare
      pair already equals the old memory value, so it is left untouched.
    - Mismatch: memory is not written, and the compare pair is loaded with
      the current memory value.
    Either way, the compare pair ends up holding the original memory value.

    Previously the branch was inverted (it stored on mismatch), and the
    store path then reloaded the registers from the freshly written memory.

    Return (assignments, [store block, load block]).
    """
    # XXX TODO: memory barrier
    e = []
    if arg1.size == 32:
        regs = gpregs32_expr
    else:
        regs = gpregs64_expr
    index1 = regs.index(arg1)
    index2 = regs.index(arg2)

    # TODO endianness
    comp_value = ExprCompose(regs[index1], regs[index1 + 1])
    new_value = ExprCompose(regs[index2], regs[index2 + 1])
    assert arg3.is_op('preinc')
    ptr = arg3.args[0]
    data = ExprMem(ptr, comp_value.size)

    loc_store = ExprLoc(ir.loc_db.add_location(), ir.IRDst.size)
    loc_do = ExprLoc(ir.loc_db.add_location(), ir.IRDst.size)
    loc_next = ExprLoc(ir.get_next_loc_key(instr), ir.IRDst.size)

    # FLAG_EQ_CMP(data, comp_value) is non-zero when memory equals the
    # compare pair: only then must the new pair be stored
    e.append(
        ExprAssign(
            ir.IRDst,
            ExprCond(ExprOp("FLAG_EQ_CMP", data, comp_value),
                     loc_store, loc_do)))

    # Match: write the new pair; registers already hold the old value
    e_store = []
    e_store.append(ExprAssign(data, new_value))
    e_store.append(ExprAssign(ir.IRDst, loc_next))
    blk_store = IRBlock(loc_store.loc_key, [AssignBlock(e_store, instr)])

    # Mismatch: load the current memory value into the compare pair
    e_do = []
    e_do.append(ExprAssign(regs[index1], data[:data.size // 2]))
    e_do.append(ExprAssign(regs[index1 + 1], data[data.size // 2:]))
    e_do.append(ExprAssign(ir.IRDst, loc_next))
    blk_do = IRBlock(loc_do.loc_key, [AssignBlock(e_do, instr)])

    return e, [blk_store, blk_do]
def mrs(ir, insr, arg1, arg2, arg3, arg4, arg5):
    """Move a system register into @arg1 (MRS).

    Only the encoding (3, c4, c2, 0) is handled. It yields, from low to high
    bits, 28 zero bits followed by of, cf, zf and nf, zero-extended to
    arg1.size.
    Raise NotImplementedError for any other system register.
    """
    is_flags_reg = (arg2.is_int(3) and arg3.is_id("c4") and
                    arg4.is_id("c2") and arg5.is_int(0))
    if not is_flags_reg:
        raise NotImplementedError("MRS not implemented")
    flags_value = ExprCompose(ExprInt(0x0, 28), of, cf, zf, nf)
    return [ExprAssign(arg1, flags_value.zeroExtend(arg1.size))], []
def read(self, ptr, size):
    """
    Return the value stored at address @ptr.

    The pointer is split into a base and an offset. When a memory array is
    registered for that base, the parts it returns are composed. Otherwise
    the plain ExprMem(@ptr, @size) is returned.
    @ptr: Expr representing the memory address
    @size: memory size (in bits), byte aligned
    """
    assert size % 8 == 0
    base, offset = get_expr_base_offset(ptr)
    memarray = self.base_to_memarray.get(base)
    if memarray is None:
        return ExprMem(ptr, size)
    return ExprCompose(*memarray.read(offset, size))
def mem_read(self, expr):
    """
    [DEV]: Override to modify the effective memory reads

    Read the symbolic value of ExprMem @expr. Resolved parts that are unknown
    memory go through self.func_read when it is set; other parts are used as
    they are. The parts are composed and simplified, and the result must keep
    @expr's size.
    @expr: ExprMem
    """
    parts = self._resolve_mem_parts(expr)
    values = [
        self.func_read(part)
        if (not known and part.is_mem() and self.func_read is not None)
        else part
        for known, part in parts
    ]
    result = self.expr_simp(ExprCompose(*values))
    assert result.size == expr.size
    return result
def test_ClassDef(self):
    """Exercise a SymbolicExecutionEngine over a small memory/register state.

    The test covers:
    - memory lookup (find_mem_by_addr);
    - full, partial and straddling memory reads through eval_expr;
    - func_read, modified() and missing-key lookups in e.symbols;
    - apply_expr, apply_change and eval_ir on a simple x = a assignment.
    """
    from miasm2.expression.expression import ExprInt, ExprId, ExprMem, \
        ExprCompose, ExprAff
    from miasm2.arch.x86.sem import ir_x86_32
    from miasm2.ir.symbexec import SymbolicExecutionEngine
    from miasm2.ir.ir import AssignBlock

    # 32-bit constant addresses
    addrX = ExprInt(-1, 32)
    addr0 = ExprInt(0, 32)
    addr1 = ExprInt(1, 32)
    addr8 = ExprInt(8, 32)
    addr9 = ExprInt(9, 32)
    addr20 = ExprInt(20, 32)
    addr40 = ExprInt(40, 32)
    addr50 = ExprInt(50, 32)
    # Memory cells: default-size accesses, plus explicit 8-bit ('v') and
    # 16-bit ('w') accesses
    mem0 = ExprMem(addr0)
    mem1 = ExprMem(addr1, 8)
    mem8 = ExprMem(addr8)
    mem9 = ExprMem(addr9)
    mem20 = ExprMem(addr20)
    mem40v = ExprMem(addr40, 8)
    mem40w = ExprMem(addr40, 16)
    mem50v = ExprMem(addr50, 8)
    mem50w = ExprMem(addr50, 16)
    id_x = ExprId('x')
    id_y = ExprId('y', 8)
    id_a = ExprId('a')
    id_eax = ExprId('eax_init')

    # Initial symbolic state
    e = SymbolicExecutionEngine(
        ir_x86_32(),
        {
            mem0: id_x,
            mem1: id_y,
            mem9: id_x,
            mem40w: id_x[:16],
            mem50v: id_y,
            id_a: addr0,
            id_eax: addr0
        })

    # Lookup of a known cell, and of an address absent from the state
    self.assertEqual(e.find_mem_by_addr(addr0), mem0)
    self.assertEqual(e.find_mem_by_addr(addrX), None)

    # The address expression is evaluated: 1 - 1 reads the cell at 0
    self.assertEqual(e.eval_expr(ExprMem(addr1 - addr1)), id_x)
    self.assertEqual(e.eval_expr(ExprMem(addr1, 8)), id_y)

    # Reads straddling known and unknown memory
    self.assertEqual(e.eval_expr(ExprMem(addr1 + addr1)),
                     ExprCompose(id_x[16:32], ExprMem(ExprInt(4, 32), 16)))
    self.assertEqual(e.eval_expr(mem8),
                     ExprCompose(id_x[0:24], ExprMem(ExprInt(11, 32), 8)))

    # Partial read of a wider cell, and wider read over a narrower cell
    self.assertEqual(e.eval_expr(mem40v), id_x[:8])
    self.assertEqual(e.eval_expr(mem50w),
                     ExprCompose(id_y, ExprMem(ExprInt(51, 32), 8)))

    # Unknown memory is returned as itself, with or without func_read
    self.assertEqual(e.eval_expr(mem20), mem20)
    e.func_read = lambda x: x
    self.assertEqual(e.eval_expr(mem20), mem20)

    self.assertEqual(set(e.modified()), set(e.symbols))
    # Looking up a cell that was never set raises KeyError
    self.assertRaises(
        KeyError, e.symbols.__getitem__, ExprMem(ExprInt(100, 32)))

    # apply_expr on an assignment updates the state and returns the value
    self.assertEqual(e.apply_expr(id_eax), addr0)
    self.assertEqual(e.apply_expr(ExprAff(id_eax, addr9)), addr9)
    self.assertEqual(e.apply_expr(id_eax), addr9)

    # apply_change / eval_ir / apply_expr

    ## x = a (with a = 0x0)
    assignblk = AssignBlock({id_x: id_a})
    e.eval_ir(assignblk)
    self.assertEqual(e.apply_expr(id_x), addr0)

    ## x = a (without replacing 'a' with 0x0)
    e.apply_change(id_x, id_a)
    self.assertEqual(e.apply_expr(id_x), id_a)

    ## x = a (with a = 0x0)
    self.assertEqual(e.apply_expr(assignblk.dst2ExprAff(id_x)), addr0)
    self.assertEqual(e.apply_expr(id_x), addr0)
"""
This example demonstrates the recovery of possible C types for an arbitrary
variable in assembly code (the types are inferred from the function argument
types). It also displays the C code used to access this variable.

Input:
* definitions of the C types that can be used by the code
* layout of structures (packed/not packed)
* prototype of the analyzed function

Algorithm:
The DepGraph of the target variable is computed, which gives possible
expressions for this variable. For each DepGraph solution, if the expression
depends on typed arguments, the code infers the variable type and displays the
C code used to access this variable.

Here be dragons:
For the moment, Miasm can infer C types (and generate C) only for simple
expressions. To summarize, Miasm only supports accesses that do not involve
arithmetic or conditional expressions. Supported forms:
* var1.field
* var1[12][4]
* *(var1.field->tab[4])

Unsupported forms:
* var1 + var2
* var1[var2+4]
* var1?var2->field:6
expr << (expr_int | expr_id | expr_loc | expr_slice | expr_mem | expr_cond |
         expr_compose | expr_op | expr_aff)


def parse_loc_key(t):
    """Build an ExprLoc from the parsed tokens [loc_key index, size]."""
    assert len(t) == 2
    loc_key, size = LocKey(t[0]), t[1]
    return ExprLoc(loc_key, size)


# Each grammar element builds the matching Expr from its parsed tokens
expr_int.setParseAction(lambda t: ExprInt(*t))
expr_id.setParseAction(lambda t: ExprId(*t))
expr_loc.setParseAction(parse_loc_key)
expr_slice.setParseAction(lambda t: ExprSlice(*t))
expr_mem.setParseAction(lambda t: ExprMem(*t))
expr_cond.setParseAction(lambda t: ExprCond(*t))
expr_compose.setParseAction(lambda t: ExprCompose(*t))
expr_op.setParseAction(lambda t: ExprOp(*t))
expr_aff.setParseAction(lambda t: ExprAssign(*t))


def str_to_expr(str_in):
    """Parse the @str_in and return the corresponding Expression
    @str_in: repr string of an Expression
    Raise RuntimeError if @str_in cannot be parsed."""
    try:
        value = expr.parseString(str_in)
    except Exception:
        # Convert parse/build failures only; a bare except also swallowed
        # KeyboardInterrupt and SystemExit
        raise RuntimeError("Cannot parse expression %s" % str_in)
    assert len(value) == 1
    return value[0]
def simp_slice(e_s, expr):
    """Slice optimization

    Apply the first matching rewrite rule to the ExprSlice @expr and return
    its result; return @expr unchanged when no rule applies.
    @e_s: simplifier whose expr_simp_wrapper() re-simplifies sub-slices
    @expr: ExprSlice to simplify
    Raise ValueError when slicing a slice for more bits than it holds.
    """
    # slice(A, 0, a.size) => A
    if expr.start == 0 and expr.stop == expr.arg.size:
        return expr.arg
    # Slice(int) => int
    if expr.arg.is_int():
        total_bit = expr.stop - expr.start
        mask = (1 << (expr.stop - expr.start)) - 1
        return ExprInt(int((expr.arg.arg >> expr.start) & mask), total_bit)
    # Slice(Slice(A, x), y) => Slice(A, z)
    if expr.arg.is_slice():
        if expr.stop - expr.start > expr.arg.stop - expr.arg.start:
            raise ValueError('slice in slice: getting more val', str(expr))

        return ExprSlice(expr.arg.arg, expr.start + expr.arg.start,
                         expr.start + expr.arg.start + (expr.stop - expr.start))
    if expr.arg.is_compose():
        # Slice(Compose(A), x) => Slice(A, y)
        # (the slice lies entirely inside one composed argument)
        for index, arg in expr.arg.iter_args():
            if index <= expr.start and index+arg.size >= expr.stop:
                return arg[expr.start - index:expr.stop - index]
        # Slice(Compose(A, B, C), x) => Compose(A, B, C) with truncated A/B/C
        out = []
        for index, arg in expr.arg.iter_args():
            # arg is before slice start
            if expr.start >= index + arg.size:
                continue
            # arg is after slice stop
            elif expr.stop <= index:
                continue
            # arg is fully included in slice
            elif expr.start <= index and index + arg.size <= expr.stop:
                out.append(arg)
                continue
            # arg is truncated at start
            if expr.start > index:
                slice_start = expr.start - index
            else:
                # arg is not truncated at start
                slice_start = 0
            # arg is truncated at stop
            if expr.stop < index + arg.size:
                # An arg truncated at both ends contains the whole slice and
                # was handled by the loop above, so slice_start is 0 here and
                # this equals expr.stop - index
                slice_stop = arg.size + expr.stop - (index + arg.size) - slice_start
            else:
                slice_stop = arg.size
            out.append(arg[slice_start:slice_stop])

        return ExprCompose(*out)

    # ExprMem(x, size)[:A] => ExprMem(x, a)
    # Keeping the low bits as a narrower access at the same address assumes
    # a little-endian layout; the original author doubted its safety.
    if (expr.arg.is_mem() and
          expr.start == 0 and
          expr.arg.size > expr.stop and expr.stop % 8 == 0):
        return ExprMem(expr.arg.arg, size=expr.stop)

    # distributivity of slice and &
    # (a & int)[x:y] => 0 if int[x:y] == 0
    if expr.arg.is_op("&") and expr.arg.args[-1].is_int():
        tmp = e_s.expr_simp_wrapper(expr.arg.args[-1][expr.start:expr.stop])
        if tmp.is_int(0):
            return tmp
    # distributivity of slice and exprcond
    # (a?int1:int2)[x:y] => (a?int1[x:y]:int2[x:y])
    if expr.arg.is_cond() and expr.arg.src1.is_int() and expr.arg.src2.is_int():
        src1 = expr.arg.src1[expr.start:expr.stop]
        src2 = expr.arg.src2[expr.start:expr.stop]
        return ExprCond(expr.arg.cond, src1, src2)

    # (a * int)[0:y] => (a[0:y] * int[0:y])
    # (the low bits of a product depend only on the low bits of its factors)
    if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int():
        args = [e_s.expr_simp_wrapper(a[expr.start:expr.stop]) for a in expr.arg.args]
        return ExprOp(expr.arg.op, *args)

    # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size
    # (a << int)[x:y] => a[x-int:y-int] with x-int >= 0
    if (expr.arg.is_op() and expr.arg.op in [">>", "<<"] and
          expr.arg.args[1].is_int()):
        arg, shift = expr.arg.args
        shift = int(shift)
        if expr.arg.op == ">>":
            if shift + expr.stop <= arg.size:
                return arg[expr.start + shift:expr.stop + shift]
        elif expr.arg.op == "<<":
            if expr.start - shift >= 0:
                return arg[expr.start - shift:expr.stop - shift]
        else:
            # Unreachable: op is restricted to ">>" / "<<" above
            raise ValueError('Bad case')

    return expr
def extr(arg1, arg2, arg3, arg4):
    # EXTR: extract arg1.size bits from the concatenation {arg2, arg3}
    # (arg2 in the low bits), starting at bit offset arg4.
    # NOTE(review): executed as plain Python, the assignment below only
    # rebinds the local name arg1 and the function returns None. It reads
    # like a body meant to be rewritten by a semantic-builder decorator (not
    # visible here) that turns assignments into IR. Confirm before
    # restructuring, and keep documentation in '#' comments in case that
    # decorator parses the function body.
    compose = ExprCompose(arg2, arg3)
    # NOTE(review): int(arg4.arg) and int(arg4) presumably give the same value
    arg1 = compose[int(arg4.arg):int(arg4) + arg1.size]
def test_ClassDef(self):
    """Exercise SymbolicExecutionEngine memory handling end to end.

    The test covers:
    - reads overlapping known and unknown memory;
    - partial and merged reads, and func_read;
    - state updates and the reads they imply;
    - partial overwrites and reverting memory to its original value;
    - symbol deletion and state copy;
    - writes wrapping around the address space;
    - a fully-populated tiny address space.
    """
    from miasm2.expression.expression import ExprInt, ExprId, ExprMem, \
        ExprCompose, ExprAff
    from miasm2.arch.x86.sem import ir_x86_32
    from miasm2.ir.symbexec import SymbolicExecutionEngine
    from miasm2.ir.ir import AssignBlock

    id_x = ExprId('x', 32)
    id_a = ExprId('a', 32)
    id_b = ExprId('b', 32)
    id_c = ExprId('c', 32)
    id_d = ExprId('d', 32)
    id_e = ExprId('e', 64)

    sb = SymbolicExecutionEngine(ir_x86_32(),
                                 {
                                     # Bytes 4..7 hold 0x11223344 (low byte first)
                                     ExprMem(ExprInt(0x4, 32), 8): ExprInt(0x44, 8),
                                     ExprMem(ExprInt(0x5, 32), 8): ExprInt(0x33, 8),
                                     ExprMem(ExprInt(0x6, 32), 8): ExprInt(0x22, 8),
                                     ExprMem(ExprInt(0x7, 32), 8): ExprInt(0x11, 8),

                                     ExprMem(ExprInt(0x20, 32), 32): id_x,

                                     # Two adjacent 32-bit symbols
                                     ExprMem(ExprInt(0x40, 32), 32): id_x,
                                     ExprMem(ExprInt(0x44, 32), 32): id_a,

                                     ExprMem(ExprInt(0x54, 32), 32): ExprInt(0x11223344, 32),

                                     # Memory at a symbolic address
                                     ExprMem(id_a, 32): ExprInt(0x11223344, 32),
                                     id_a: ExprInt(0, 32),
                                     id_b: ExprInt(0, 32),

                                     ExprMem(id_c, 32): ExprMem(id_d + ExprInt(0x4, 32), 32),
                                     ExprMem(id_c + ExprInt(0x4, 32), 32): ExprMem(id_d + ExprInt(0x8, 32), 32),

                                 })

    self.assertEqual(sb.eval_expr(ExprInt(1, 32)-ExprInt(1, 32)), ExprInt(0, 32))

    ## Test with unknown mem + integer (bytes 4..7 known, the rest unknown)
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0, 32), 32)), ExprMem(ExprInt(0, 32), 32))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(1, 32), 32)), ExprCompose(ExprMem(ExprInt(1, 32), 24), ExprInt(0x44, 8)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(2, 32), 32)), ExprCompose(ExprMem(ExprInt(2, 32), 16), ExprInt(0x3344, 16)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(3, 32), 32)), ExprCompose(ExprMem(ExprInt(3, 32), 8), ExprInt(0x223344, 24)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(4, 32), 32)), ExprInt(0x11223344, 32))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(5, 32), 32)), ExprCompose(ExprInt(0x112233, 24), ExprMem(ExprInt(8, 32), 8)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(6, 32), 32)), ExprCompose(ExprInt(0x1122, 16), ExprMem(ExprInt(8, 32), 16)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(7, 32), 32)),
                     ExprCompose(ExprInt(0x11, 8), ExprMem(ExprInt(8, 32), 24)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(8, 32), 32)), ExprMem(ExprInt(8, 32), 32))

    ## Test with unknown mem + integer (same, around the 32-bit cell at 0x54)
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x50, 32), 32)), ExprMem(ExprInt(0x50, 32), 32))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x51, 32), 32)), ExprCompose(ExprMem(ExprInt(0x51, 32), 24), ExprInt(0x44, 8)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x52, 32), 32)), ExprCompose(ExprMem(ExprInt(0x52, 32), 16), ExprInt(0x3344, 16)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x53, 32), 32)), ExprCompose(ExprMem(ExprInt(0x53, 32), 8), ExprInt(0x223344, 24)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x54, 32), 32)), ExprInt(0x11223344, 32))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x55, 32), 32)), ExprCompose(ExprInt(0x112233, 24), ExprMem(ExprInt(0x58, 32), 8)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x56, 32), 32)), ExprCompose(ExprInt(0x1122, 16), ExprMem(ExprInt(0x58, 32), 16)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x57, 32), 32)), ExprCompose(ExprInt(0x11, 8), ExprMem(ExprInt(0x58, 32), 24)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x58, 32), 32)), ExprMem(ExprInt(0x58, 32), 32))

    ## Test with unknown mem + id (around the cell at 0x20 holding id_x)
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x1D, 32), 32)), ExprCompose(ExprMem(ExprInt(0x1D, 32), 24), id_x[:8]))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x1E, 32), 32)), ExprCompose(ExprMem(ExprInt(0x1E, 32), 16), id_x[:16]))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x1F, 32), 32)), ExprCompose(ExprMem(ExprInt(0x1F, 32), 8), id_x[:24]))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x20, 32), 32)), id_x)
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x21, 32), 32)), ExprCompose(id_x[8:], ExprMem(ExprInt(0x24, 32), 8)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x22, 32), 32)), ExprCompose(id_x[16:], ExprMem(ExprInt(0x24, 32), 16)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x23, 32), 32)),
                     ExprCompose(id_x[24:], ExprMem(ExprInt(0x24, 32), 24)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x24, 32), 32)), ExprMem(ExprInt(0x24, 32), 32))

    ## Partial read
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(4, 32), 8)), ExprInt(0x44, 8))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x20, 32), 8)), id_x[:8])
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x23, 32), 8)), id_x[24:])

    ## Merge: reads spanning the two adjacent symbols at 0x40 and 0x44
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x40, 32), 64)), ExprCompose(id_x, id_a))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x42, 32), 32)), ExprCompose(id_x[16:], id_a[:16]))

    # Merge memory: halves of two symbolic-address cells merge into one read
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x100, 32), 32)), ExprMem(ExprInt(0x100, 32), 32))
    self.assertEqual(sb.eval_expr(ExprMem(id_c + ExprInt(0x2, 32), 32)), ExprMem(id_d  + ExprInt(0x6, 32), 32))

    ## Func read: custom reader for unknown memory
    def custom_func_read(mem):
        if mem == ExprMem(ExprInt(0x1000, 32), 32):
            return id_x
        return mem

    sb.func_read = custom_func_read

    ## Unmodified read (known memory does not go through func_read)
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(4, 32), 8)), ExprInt(0x44, 8))

    ## Modified read
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x1000, 32), 32)), id_x)

    ## Apply_change / eval_ir / apply_expr

    ## x = a (with a = 0x0)
    assignblk = AssignBlock({id_x:id_a})
    sb.eval_updt_assignblk(assignblk)
    self.assertEqual(sb.eval_expr(id_x), ExprInt(0, 32))

    ## x = a (without replacing 'a' with 0x0)
    sb.apply_change(id_x, id_a)
    self.assertEqual(sb.eval_expr(id_x), id_a)

    ## x = a (with a = 0x0)
    self.assertEqual(sb.eval_updt_expr(assignblk.dst2ExprAff(id_x)), ExprInt(0, 32))
    self.assertEqual(sb.eval_expr(id_x), ExprInt(0, 32))
    self.assertEqual(sb.eval_updt_expr(id_x), ExprInt(0, 32))

    sb.dump()

    ## state: collect every expression read by the modified entries
    reads = set()
    for dst, src in sb.modified():
        reads.update(ExprAff(dst, src).get_r())

    self.assertEqual(reads, set([
        id_x, id_a,
        ExprMem(id_d + ExprInt(0x4, 32), 32),
        ExprMem(id_d + ExprInt(0x8, 32), 32),
    ]))

    # Erase low id_x byte with 0xFF: the cell at 0x20 is split
    sb.apply_change(ExprMem(ExprInt(0x20, 32), 8), ExprInt(0xFF, 8))
    state = dict(sb.modified(ids=False))
    self.assertEqual(state[ExprMem(ExprInt(0x20, 32), 8)], ExprInt(0xFF, 8))
    self.assertEqual(state[ExprMem(ExprInt(0x21, 32), 24)], id_x[8:32])

    # Erase high id_x byte with 0xEE
    sb.apply_change(ExprMem(ExprInt(0x23, 32), 8), ExprInt(0xEE, 8))

    state = dict(sb.modified(ids=False))
    self.assertEqual(state[ExprMem(ExprInt(0x20, 32), 8)], ExprInt(0xFF, 8))
    self.assertEqual(state[ExprMem(ExprInt(0x21, 32), 16)], id_x[8:24])
    self.assertEqual(state[ExprMem(ExprInt(0x23, 32), 8)], ExprInt(0xEE, 8))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x22, 32), 32)), ExprCompose(id_x[16:24], ExprInt(0xEE, 8), ExprMem(ExprInt(0x24, 32), 16)))

    # Erase low byte of 0x11223344 with 0xFF at 0x54
    sb.apply_change(ExprMem(ExprInt(0x54, 32), 8), ExprInt(0xFF, 8))

    # Erase the second byte (0x33) of 0x11223344 with 0xFF at id_a + 1
    sb.apply_change(ExprMem(id_a + ExprInt(0x1, 32), 8), ExprInt(0xFF, 8))
    state = dict(sb.modified(ids=False))
    self.assertEqual(state[ExprMem(id_a + ExprInt(0x1, 32), 8)], ExprInt(0xFF, 8))
    self.assertEqual(state[ExprMem(id_a + ExprInt(0x2, 32), 16)], ExprInt(0x1122, 16))

    # Write uint32_t at 0xFFFFFFFE: the upper half wraps around to address 0
    sb.apply_change(ExprMem(ExprInt(0xFFFFFFFE, 32), 32), ExprInt(0x11223344, 32))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0, 32), 16)), ExprInt(0x1122, 16))

    # Revert memory to original value at 0x42
    sb.apply_change(ExprMem(ExprInt(0x42, 32), 32), ExprMem(ExprInt(0x42, 32), 32))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x42, 32), 32)), ExprMem(ExprInt(0x42, 32), 32))

    # Revert memory to original value at c + 0x2
    sb.apply_change(ExprMem(id_c + ExprInt(0x2, 32), 32), ExprMem(id_c + ExprInt(0x2, 32), 32))
    self.assertEqual(sb.eval_expr(ExprMem(id_c + ExprInt(0x2, 32), 32)), ExprMem(id_c + ExprInt(0x2, 32), 32))

    # Test del symbol
    del sb.symbols[id_a]
    sb.dump()
    del sb.symbols[ExprMem(id_a, 8)]
    print "*"*40, 'Orig:'
    sb.dump()

    sb_cp = sb.symbols.copy()
    print "*"*40, 'Copy:'
    sb_cp.dump()

    # Add symbol at address limit (the write wraps around)
    sb.apply_change(ExprMem(ExprInt(0xFFFFFFFE, 32), 32), id_c)
    sb.dump()
    found = False
    for dst, src in sb.symbols.iteritems():
        if dst == ExprMem(ExprInt(0xFFFFFFFE, 32), 32) and src == id_c:
            found = True
    assert found


    # Add symbol at address limit (signed limit, no wraparound)
    sb.apply_change(ExprMem(ExprInt(0x7FFFFFFE, 32), 32), id_c)
    sb.dump()
    found = False
    for dst, src in sb.symbols.iteritems():
        if dst == ExprMem(ExprInt(0x7FFFFFFE, 32), 32) and src == id_c:
            found = True
    assert found


    # Add truncated symbol at address limit
    sb.apply_change(ExprMem(ExprInt(0xFFFFFFFC, 32), 64), id_e)
    # Revert parts of memory: only id_e[16:48] at 0xFFFFFFFE should remain
    sb.apply_change(ExprMem(ExprInt(0xFFFFFFFC, 32), 16), ExprMem(ExprInt(0xFFFFFFFC, 32), 16))
    sb.apply_change(ExprMem(ExprInt(0x2, 32), 16), ExprMem(ExprInt(0x2, 32), 16))
    sb.dump()
    found = False
    for dst, src in sb.symbols.iteritems():
        if dst == ExprMem(ExprInt(0xFFFFFFFE, 32), 32) and src == id_e[16:48]:
            found = True
    assert found


    sb_empty = SymbolicExecutionEngine(ir_x86_32(), {})
    sb_empty.dump()


    # Test memory full
    print 'full'
    arch_addr8 = ir_x86_32()
    # Hack to obtain tiny address space: 5-bit addresses, i.e. 32 bytes
    # (256 bits) of memory in total
    arch_addr8.addrsize = 5
    sb_addr8 = SymbolicExecutionEngine(arch_addr8, {})
    sb_addr8.dump()
    # Fulfill memory: a single 256-bit write covers the whole address space
    sb_addr8.apply_change(ExprMem(ExprInt(0, 5), 256), ExprInt(0, 256))
    sb_addr8.dump()
    variables = sb_addr8.symbols.items()
    assert variables == [(ExprMem(ExprInt(0, 5), 256), ExprInt(0, 256))]

    print sb_addr8.symbols.symbols_mem

    # A full-size write at another base replaces the whole content
    sb_addr8.apply_change(ExprMem(ExprInt(0x5, 5), 256), ExprInt(0x123, 256))
    sb_addr8.dump()
    variables = sb_addr8.symbols.items()
    assert variables == [(ExprMem(ExprInt(0x5, 5), 256), ExprInt(0x123, 256))]
    print sb_addr8.symbols.symbols_mem

    print 'dump'
    sb_addr8.symbols.symbols_mem.dump()


    sb.dump()
    # Deleting a 32-bit cell that is not stored as such must raise KeyError
    try:
        del sb.symbols.symbols_mem[ExprMem(ExprInt(0xFFFFFFFF, 32), 32)]
    except KeyError:
        # ok
        pass
    else:
        raise RuntimeError("Should raise error!")


    del sb.symbols.symbols_mem[ExprMem(ExprInt(0xFFFFFFFF, 32), 16)]
    sb.dump()
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0xFFFFFFFE, 32), 32)),
                     ExprCompose(id_e[16:24], ExprMem(ExprInt(0xFFFFFFFF, 32), 16), id_e[40:48]))
    # delete_partial removes whatever part of the range is present
    sb.symbols.symbols_mem.delete_partial(ExprMem(ExprInt(0xFFFFFFFF, 32), 32))

    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0xFFFFFFFE, 32), 32)),
                     ExprCompose(id_e[16:24], ExprMem(ExprInt(0xFFFFFFFF, 32), 24)))

    sb.dump()

    assert ExprMem(ExprInt(0xFFFFFFFE, 32), 8) in sb.symbols
    assert ExprMem(ExprInt(0xFFFFFFFE, 32), 32) not in sb.symbols
    assert sb.symbols.symbols_mem.contains_partial(ExprMem(ExprInt(0xFFFFFFFE, 32), 32))
    assert not sb.symbols.symbols_mem.contains_partial(ExprMem(ExprInt(0xFFFFFFFF, 32), 8))

    assert sb_addr8.symbols.keys() == [ExprMem(ExprInt(0x5, 5), 256)]
from miasm2.expression.parser import str_to_expr
from miasm2.expression.expression import ExprInt, ExprId, ExprSlice, ExprMem, \
    ExprCond, ExprCompose, ExprOp, ExprAff, ExprLoc, LocKey

# One sample of each expression kind; each must survive a
# repr() -> str_to_expr() round trip unchanged.
ROUNDTRIP_SAMPLES = [
    ExprInt(0x12, 32),
    ExprId('test', 32),
    ExprLoc(LocKey(12), 32),
    ExprSlice(ExprInt(0x10, 32), 0, 8),
    ExprMem(ExprInt(0x10, 32), 32),
    ExprCond(ExprInt(0x10, 32), ExprInt(0x11, 32), ExprInt(0x12, 32)),
    ExprCompose(ExprInt(0x10, 16), ExprInt(0x11, 8), ExprInt(0x12, 8)),
    ExprInt(0x11, 8) + ExprInt(0x12, 8),
    ExprAff(ExprId('EAX', 32), ExprInt(0x12, 32)),
]

for expr_test in ROUNDTRIP_SAMPLES:
    print('Test: %s' % expr_test)
    assert str_to_expr(repr(expr_test)) == expr_test
def simp_cst_propagation(e_s, expr):
    """Constant propagation and algebraic identities on an ExprOp.

    This pass includes:
    - Constant folding
    - Common logical identities
    - Common binary identities

    @e_s: simplifier callable, applied here to sub-expressions built by this
          pass (presumably the global expression simplifier; confirm at the
          caller)
    @expr: ExprOp to simplify
    Returns the simplified expression, or an ExprOp rebuilt from the
    (possibly reduced) operand list when no rule fully applies.
    """
    # merge associatif op
    args = list(expr.args)
    op_name = expr.op
    # simpl integer manip
    # int OP int => int
    # TODO: <<< >>> << >> are architecture dependant
    if op_name in op_propag_cst:
        while (len(args) >= 2 and
               args[-1].is_int() and
               args[-2].is_int()):
            int2 = args.pop()
            int1 = args.pop()
            if op_name == '+':
                out = int1.arg + int2.arg
            elif op_name == '*':
                out = int1.arg * int2.arg
            elif op_name == '**':
                out = int1.arg ** int2.arg
            elif op_name == '^':
                out = int1.arg ^ int2.arg
            elif op_name == '&':
                out = int1.arg & int2.arg
            elif op_name == '|':
                out = int1.arg | int2.arg
            elif op_name == '>>':
                # A shift count larger than the operand size always gives 0
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg >> int2.arg
            elif op_name == '<<':
                # Same as above. Short-circuiting also avoids building a huge
                # intermediate integer for very large shift counts.
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg << int2.arg
            elif op_name == 'a>>':
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 >> tmp2)
            elif op_name == '>>>':
                out = (int1.arg >> (int2.arg % int2.size) |
                       int1.arg << ((int1.size - int2.arg) % int2.size))
            elif op_name == '<<<':
                out = (int1.arg << (int2.arg % int2.size) |
                       int1.arg >> ((int1.size - int2.arg) % int2.size))
            elif op_name == '/':
                out = int1.arg / int2.arg
            elif op_name == '%':
                out = int1.arg % int2.arg
            elif op_name == 'idiv':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)
            elif op_name == 'imod':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'umod':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'udiv':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)

            args.append(ExprInt(out, int1.size))

    # bsf(int) => int
    if op_name == "bsf" and args[0].is_int() and args[0].arg != 0:
        i = 0
        while args[0].arg & (1 << i) == 0:
            i += 1
        return ExprInt(i, args[0].size)

    # bsr(int) => int
    if op_name == "bsr" and args[0].is_int() and args[0].arg != 0:
        i = args[0].size - 1
        while args[0].arg & (1 << i) == 0:
            i -= 1
        return ExprInt(i, args[0].size)

    # -(-(A)) => A
    if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and
            len(args[0].args) == 1):
        return args[0].args[0]

    # -(int) => -int
    if op_name == '-' and len(args) == 1 and args[0].is_int():
        return ExprInt(-int(args[0]), expr.size)

    # A op 0 =>A
    if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
        if args[-1].is_int(0):
            args.pop()

    # A - 0 =>A
    if op_name == '-' and len(args) > 1 and args[-1].is_int(0):
        assert len(args) == 2  # Op '-' with more than 2 args: SantityCheckError
        return args[0]

    # A * 1 =>A
    if op_name == "*" and len(args) > 1 and args[-1].is_int(1):
        args.pop()

    # for cannon form
    # A * -1 => - A
    if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask:
        args.pop()
        args[-1] = -args[-1]

    # op A => A
    if op_name in ['+', '*', '^', '&', '|', '>>', '<<',
                   'a>>', '<<<', '>>>', 'idiv', 'imod', 'umod', 'udiv'] and len(args) == 1:
        return args[0]

    # A-B => A + (-B)
    if op_name == '-' and len(args) > 1:
        if len(args) > 2:
            raise ValueError(
                'sanity check fail on expr -: should have one or 2 args ' +
                '%r %s' % (expr, expr))
        return ExprOp('+', args[0], -args[1])

    # A op 0 => 0
    # Any zero operand absorbs the whole '&' / '*' chain, wherever it sits
    # in the operand list.
    if op_name in ['&', "*"] and any(arg.is_int(0) for arg in args):
        return ExprInt(0, expr.size)

    # - (A + B +...) => -A + -B + -C
    if op_name == '-' and len(args) == 1 and args[0].is_op('+'):
        args = [-a for a in args[0].args]
        return ExprOp('+', *args)

    # -(a?int1:int2) => (a?-int1:-int2)
    if (op_name == '-' and len(args) == 1 and
            args[0].is_cond() and
            args[0].src1.is_int() and args[0].src2.is_int()):
        int1 = args[0].src1
        int2 = args[0].src2
        int1 = ExprInt(-int1.arg, int1.size)
        int2 = ExprInt(-int2.arg, int2.size)
        return ExprCond(args[0].cond, int1, int2)

    i = 0
    while i < len(args) - 1:
        j = i + 1
        while j < len(args):
            # A ^ A => 0
            if op_name == '^' and args[i] == args[j]:
                args[i] = ExprInt(0, args[i].size)
                del args[j]
                continue
            # A + (- A) => 0
            if op_name == '+' and args[j].is_op("-"):
                if len(args[j].args) == 1 and args[i] == args[j].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # (- A) + A => 0
            if op_name == '+' and args[i].is_op("-"):
                if len(args[i].args) == 1 and args[j] == args[i].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # A | A => A
            if op_name == '|' and args[i] == args[j]:
                del args[j]
                continue
            # A & A => A
            if op_name == '&' and args[i] == args[j]:
                del args[j]
                continue
            j += 1
        i += 1

    if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1:
        return args[0]

    # A <<< A.size => A
    if (op_name in ['<<<', '>>>'] and
            args[1].is_int() and
            args[1].arg == args[0].size):
        return args[0]

    # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>)
    # Only done when it is provably safe: building X+Y / X-Y symbolically
    # may overflow modulo 2**size, which does not match a rotation modulo
    # the operand size.
    if (op_name in ['<<<', '>>>'] and
            args[0].is_op() and
            args[0].op in ['<<<', '>>>']):
        A = args[0].args[0]
        X = args[0].args[1]
        Y = args[1]
        if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size):
            # Opposite rotations by the same amount cancel out
            return args[0].args[0]
        elif X.is_int() and Y.is_int():
            new_X = int(X) % expr.size
            new_Y = int(Y) % expr.size
            if op_name == args[0].op:
                rot = (new_X + new_Y) % expr.size
                op = op_name
            else:
                rot = new_Y - new_X
                op = op_name
            if rot < 0:
                rot = - rot
                op = {">>>": "<<<", "<<<": ">>>"}[op_name]
            args = [A, ExprInt(rot, expr.size)]
            op_name = op
        else:
            # Do not consider this case, too tricky (overflow on addition /
            # substraction)
            pass

    # A >> X >> Y => A >> (X+Y) if X + Y does not overflow
    # To be sure, only consider the simplification when X.msb and Y.msb are 0
    if (op_name in ['<<', '>>'] and
            args[0].is_op(op_name)):
        X = args[0].args[1]
        Y = args[1]
        if (e_s(X.msb()) == ExprInt(0, 1) and
                e_s(Y.msb()) == ExprInt(0, 1)):
            args = [args[0].args[0], X + Y]

    # ((A & A.mask)
    if op_name == "&" and args[-1] == expr.mask:
        return ExprOp('&', *args[:-1])

    # ((A | A.mask)
    if op_name == "|" and args[-1] == expr.mask:
        return args[-1]

    # ! (!X + int) => X - int
    # TODO

    # ((A & mask) >> shift) whith mask < 2**shift => 0
    if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
        if (args[0].args[1].is_int() and
                2 ** args[1].arg > args[0].args[1].arg):
            return ExprInt(0, args[0].size)

    # parity(int) => int
    if op_name == 'parity' and args[0].is_int():
        return ExprInt(parity(int(args[0])), 1)

    # (-a) * b * (-c) * (-d) => (-a) * b * c * d
    if op_name == "*" and len(args) > 1:
        new_args = []
        counter = 0
        for arg in args:
            if arg.is_op('-') and len(arg.args) == 1:
                new_args.append(arg.args[0])
                counter += 1
            else:
                new_args.append(arg)
        if counter % 2:
            return -ExprOp(op_name, *new_args)
        args = new_args

    # A << int with A ExprCompose => move index
    if (op_name == "<<" and args[0].is_compose() and
            args[1].is_int() and int(args[1]) != 0):
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index + shift, index + shift + arg.size))
        # filter out expression
        filter_args = []
        min_index = final_size
        for tmp, start, stop in new_args:
            if start >= final_size:
                continue
            if stop > final_size:
                tmp = tmp[:tmp.size - (stop - final_size)]
                stop = final_size
            filter_args.append(tmp)
            min_index = min(start, min_index)
        # create entry 0
        assert min_index != 0
        tmp = ExprInt(0, min_index)
        args = [tmp] + filter_args
        return ExprCompose(*args)

    # A >> int with A ExprCompose => move index
    if op_name == ">>" and args[0].is_compose() and args[1].is_int():
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index - shift, index + arg.size - shift))
        # filter out expression
        filter_args = []
        max_index = 0
        for tmp, start, stop in new_args:
            if stop <= 0:
                continue
            if start < 0:
                tmp = tmp[-start:]
                start = 0
            filter_args.append(tmp)
            max_index = max(stop, max_index)
        # create entry 0
        tmp = ExprInt(0, final_size - max_index)
        args = filter_args + [tmp]
        return ExprCompose(*args)

    # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b)
    if op_name in ['|', '&', '^'] and all([arg.is_compose() for arg in args]):
        bounds = set()
        for arg in args:
            bound = tuple([tmp.size for tmp in arg.args])
            bounds.add(bound)
        if len(bounds) == 1:
            bound = list(bounds)[0]
            new_args = [[tmp] for tmp in args[0].args]
            for sub_arg in args[1:]:
                for i, tmp in enumerate(sub_arg.args):
                    new_args[i].append(tmp)
            args = []
            for i, arg in enumerate(new_args):
                args.append(ExprOp(op_name, *arg))
            return ExprCompose(*args)

    # <<<c_rez, >>>c_rez
    if op_name in [">>>c_rez", "<<<c_rez"]:
        assert len(args) == 3
        dest, rounds, carry_flag = args
        # Skipped if rounds is 0
        if rounds.is_int(0):
            return dest
        elif all(arg.is_int() for arg in args):
            # The expression can be resolved
            tmp = int(dest)
            carry_flag = int(carry_flag)
            size = dest.size
            tmp_count = (int(rounds) &
                         (0x3f if size == 64 else 0x1f)) % (size + 1)
            if op_name == ">>>c_rez":
                while tmp_count != 0:
                    tmp_cf = tmp & 1
                    tmp = (tmp >> 1) + (carry_flag << (size - 1))
                    carry_flag = tmp_cf
                    tmp_count -= 1
                    tmp &= int(dest.mask)
            elif op_name == "<<<c_rez":
                while tmp_count != 0:
                    tmp_cf = (tmp >> (size - 1)) & 1
                    tmp = (tmp << 1) + carry_flag
                    carry_flag = tmp_cf
                    tmp_count -= 1
                    tmp &= int(dest.mask)
            else:
                raise RuntimeError("Unknown operation: %s" % op_name)
            return ExprInt(tmp, size=dest.size)

    return ExprOp(op_name, *args)
def simp_cst_propagation(e_s, expr):
    """Constant propagation and algebraic identities on an ExprOp.

    This pass includes:
    - Constant folding
    - Common logical identities
    - Common binary identities

    @e_s: simplifier callable, applied here to sub-expressions built by this
          pass (presumably the global expression simplifier; confirm at the
          caller)
    @expr: ExprOp to simplify
    Returns the simplified expression, or an ExprOp rebuilt from the
    (possibly reduced) operand list when no rule fully applies.
    """
    # merge associatif op
    args = list(expr.args)
    op_name = expr.op
    # simpl integer manip
    # int OP int => int
    # TODO: <<< >>> << >> are architecture dependant
    if op_name in op_propag_cst:
        while (len(args) >= 2 and
               args[-1].is_int() and
               args[-2].is_int()):
            int2 = args.pop()
            int1 = args.pop()
            if op_name == '+':
                out = int1.arg + int2.arg
            elif op_name == '*':
                out = int1.arg * int2.arg
            elif op_name == '**':
                out = int1.arg ** int2.arg
            elif op_name == '^':
                out = int1.arg ^ int2.arg
            elif op_name == '&':
                out = int1.arg & int2.arg
            elif op_name == '|':
                out = int1.arg | int2.arg
            elif op_name == '>>':
                # A shift count larger than the operand size always gives 0
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg >> int2.arg
            elif op_name == '<<':
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg << int2.arg
            elif op_name == 'a>>':
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                if tmp2 > int1.size:
                    # Over-shift: result is all sign bits
                    is_signed = int(int1) & (1 << (int1.size - 1))
                    if is_signed:
                        out = -1
                    else:
                        out = 0
                else:
                    out = mod_size2uint[int1.arg.size](tmp1 >> tmp2)
            elif op_name == '>>>':
                shifter = int2.arg % int2.size
                out = (int1.arg >> shifter) | (int1.arg << (int2.size - shifter))
            elif op_name == '<<<':
                shifter = int2.arg % int2.size
                out = (int1.arg << shifter) | (int1.arg >> (int2.size - shifter))
            elif op_name == '/':
                out = int1.arg / int2.arg
            elif op_name == '%':
                out = int1.arg % int2.arg
            elif op_name == 'idiv':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)
            elif op_name == 'imod':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'umod':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'udiv':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)

            args.append(ExprInt(out, int1.size))

    # cnttrailzeros(int) => int
    if op_name == "cnttrailzeros" and args[0].is_int():
        i = 0
        while args[0].arg & (1 << i) == 0 and i < args[0].size:
            i += 1
        return ExprInt(i, args[0].size)

    # cntleadzeros(int) => int
    if op_name == "cntleadzeros" and args[0].is_int():
        if args[0].arg == 0:
            return ExprInt(args[0].size, args[0].size)
        i = args[0].size - 1
        while args[0].arg & (1 << i) == 0:
            i -= 1
        return ExprInt(expr.size - (i + 1), args[0].size)

    # -(-(A)) => A
    if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and
            len(args[0].args) == 1):
        return args[0].args[0]

    # -(int) => -int
    if op_name == '-' and len(args) == 1 and args[0].is_int():
        return ExprInt(-int(args[0]), expr.size)

    # A op 0 =>A
    if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
        if args[-1].is_int(0):
            args.pop()

    # A - 0 =>A
    if op_name == '-' and len(args) > 1 and args[-1].is_int(0):
        assert len(args) == 2  # Op '-' with more than 2 args: SantityCheckError
        return args[0]

    # A * 1 =>A
    if op_name == "*" and len(args) > 1 and args[-1].is_int(1):
        args.pop()

    # for cannon form
    # A * -1 => - A
    if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask:
        args.pop()
        args[-1] = - args[-1]

    # op A => A
    if op_name in ['+', '*', '^', '&', '|', '>>', '<<',
                   'a>>', '<<<', '>>>', 'idiv', 'imod', 'umod', 'udiv'] and len(args) == 1:
        return args[0]

    # A-B => A + (-B)
    if op_name == '-' and len(args) > 1:
        if len(args) > 2:
            raise ValueError(
                'sanity check fail on expr -: should have one or 2 args ' +
                '%r %s' % (expr, expr))
        return ExprOp('+', args[0], -args[1])

    # A op 0 => 0
    if op_name in ['&', "*"] and args[-1].is_int(0):
        return ExprInt(0, expr.size)

    # - (A + B +...) => -A + -B + -C
    if op_name == '-' and len(args) == 1 and args[0].is_op('+'):
        args = [-a for a in args[0].args]
        return ExprOp('+', *args)

    # -(a?int1:int2) => (a?-int1:-int2)
    if (op_name == '-' and len(args) == 1 and
            args[0].is_cond() and
            args[0].src1.is_int() and args[0].src2.is_int()):
        int1 = args[0].src1
        int2 = args[0].src2
        int1 = ExprInt(-int1.arg, int1.size)
        int2 = ExprInt(-int2.arg, int2.size)
        return ExprCond(args[0].cond, int1, int2)

    i = 0
    while i < len(args) - 1:
        j = i + 1
        while j < len(args):
            # A ^ A => 0
            if op_name == '^' and args[i] == args[j]:
                args[i] = ExprInt(0, args[i].size)
                del args[j]
                continue
            # A + (- A) => 0
            if op_name == '+' and args[j].is_op("-"):
                if len(args[j].args) == 1 and args[i] == args[j].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # (- A) + A => 0
            if op_name == '+' and args[i].is_op("-"):
                if len(args[i].args) == 1 and args[j] == args[i].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # A | A => A
            if op_name == '|' and args[i] == args[j]:
                del args[j]
                continue
            # A & A => A
            if op_name == '&' and args[i] == args[j]:
                del args[j]
                continue
            j += 1
        i += 1

    if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1:
        return args[0]

    # A <<< A.size => A
    if (op_name in ['<<<', '>>>'] and
            args[1].is_int() and
            args[1].arg == args[0].size):
        return args[0]

    # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
    if (op_name in ['<<<', '>>>'] and
            args[0].is_op() and
            args[0].op in ['<<<', '>>>']):
        A = args[0].args[0]
        X = args[0].args[1]
        Y = args[1]
        if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size):
            # Opposite rotations by the same amount cancel out
            return args[0].args[0]
        elif X.is_int() and Y.is_int():
            new_X = int(X) % expr.size
            new_Y = int(Y) % expr.size
            if op_name == args[0].op:
                rot = (new_X + new_Y) % expr.size
                op = op_name
            else:
                rot = new_Y - new_X
                op = op_name
            if rot < 0:
                rot = - rot
                op = {">>>": "<<<", "<<<": ">>>"}[op_name]
            args = [A, ExprInt(rot, expr.size)]
            op_name = op
        else:
            # Do not consider this case, too tricky (overflow on addition /
            # substraction)
            pass

    # A >> X >> Y => A >> (X+Y) if X + Y does not overflow
    # To be sure, only consider the simplification when X.msb and Y.msb are 0
    if (op_name in ['<<', '>>'] and
            args[0].is_op(op_name)):
        X = args[0].args[1]
        Y = args[1]
        if (e_s(X.msb()) == ExprInt(0, 1) and
                e_s(Y.msb()) == ExprInt(0, 1)):
            args = [args[0].args[0], X + Y]

    # ((var >> int1) << int1) => var & mask
    # ((var << int1) >> int1) => var & mask
    # Only valid for *opposite* shifts: compare against the inner operator
    # name (comparing against the inner Expr itself was always true and let
    # (var >> n) >> n be wrongly rewritten as a mask).
    if (op_name in ['<<', '>>'] and
            args[0].is_op() and
            args[0].op in ['<<', '>>'] and
            op_name != args[0].op):
        var = args[0].args[0]
        int1 = args[0].args[1]
        int2 = args[1]
        if int1 == int2 and int1.is_int() and int(int1) < expr.size:
            if op_name == '>>':
                mask = ExprInt((1 << (expr.size - int(int1))) - 1, expr.size)
            else:
                mask = ExprInt(
                    ((1 << int(int1)) - 1) ^ ((1 << expr.size) - 1),
                    expr.size
                )
            ret = var & mask
            return ret

    # ((A & A.mask)
    if op_name == "&" and args[-1] == expr.mask:
        return ExprOp('&', *args[:-1])

    # ((A | A.mask)
    if op_name == "|" and args[-1] == expr.mask:
        return args[-1]

    # ! (!X + int) => X - int
    # TODO

    # ((A & mask) >> shift) whith mask < 2**shift => 0
    if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
        if (args[0].args[1].is_int() and
                2 ** args[1].arg > args[0].args[1].arg):
            return ExprInt(0, args[0].size)

    # parity(int) => int
    if op_name == 'parity' and args[0].is_int():
        return ExprInt(parity(int(args[0])), 1)

    # (-a) * b * (-c) * (-d) => (-a) * b * c * d
    if op_name == "*" and len(args) > 1:
        new_args = []
        counter = 0
        for arg in args:
            if arg.is_op('-') and len(arg.args) == 1:
                new_args.append(arg.args[0])
                counter += 1
            else:
                new_args.append(arg)
        if counter % 2:
            return -ExprOp(op_name, *new_args)
        args = new_args

    # A << int with A ExprCompose => move index
    if (op_name == "<<" and args[0].is_compose() and
            args[1].is_int() and int(args[1]) != 0):
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index + shift, index + shift + arg.size))
        # filter out expression
        filter_args = []
        min_index = final_size
        for tmp, start, stop in new_args:
            if start >= final_size:
                continue
            if stop > final_size:
                tmp = tmp[:tmp.size - (stop - final_size)]
                stop = final_size
            filter_args.append(tmp)
            min_index = min(start, min_index)
        # create entry 0
        assert min_index != 0
        tmp = ExprInt(0, min_index)
        args = [tmp] + filter_args
        return ExprCompose(*args)

    # A >> int with A ExprCompose => move index
    if op_name == ">>" and args[0].is_compose() and args[1].is_int():
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index - shift, index + arg.size - shift))
        # filter out expression
        filter_args = []
        max_index = 0
        for tmp, start, stop in new_args:
            if stop <= 0:
                continue
            if start < 0:
                tmp = tmp[-start:]
                start = 0
            filter_args.append(tmp)
            max_index = max(stop, max_index)
        # create entry 0
        tmp = ExprInt(0, final_size - max_index)
        args = filter_args + [tmp]
        return ExprCompose(*args)

    # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b)
    if op_name in ['|', '&', '^'] and all([arg.is_compose() for arg in args]):
        bounds = set()
        for arg in args:
            bound = tuple([tmp.size for tmp in arg.args])
            bounds.add(bound)
        if len(bounds) == 1:
            bound = list(bounds)[0]
            new_args = [[tmp] for tmp in args[0].args]
            for sub_arg in args[1:]:
                for i, tmp in enumerate(sub_arg.args):
                    new_args[i].append(tmp)
            args = []
            for i, arg in enumerate(new_args):
                args.append(ExprOp(op_name, *arg))
            return ExprCompose(*args)

    return ExprOp(op_name, *args)
def apply_expr_on_state_visit_cache(self, expr, state, cache, level=0):
    """
    Evaluate @expr against @state, depth first: the node's sub-expressions
    are evaluated recursively, then the rebuilt node is simplified with
    self.expr_simp.

    @expr: expression to evaluate
    @state: mapping from expressions to their current values
    @cache: mapping of already evaluated expressions, updated in place
    @level: recursion depth, increased by one for each sub-expression

    If any evaluated operand is the TOP identifier, the node itself becomes
    TOP (via exprid_top).
    """
    if expr in cache:
        result = cache[expr]
    elif expr in state:
        # Values read straight from the state are returned as-is
        # (neither simplified nor cached)
        return state[expr]
    elif expr.is_int():
        result = expr
    elif expr.is_id():
        label = expr.name
        if isinstance(label, asmblock.asm_label) and label.offset is not None:
            # Resolved label: replace it by its address
            result = ExprInt(label.offset, expr.size)
        elif expr in self.regstop:
            result = exprid_top(expr)
        else:
            result = state.get(expr, expr)
    elif expr.is_mem():
        result = self.manage_mem(expr, state, cache, level)
    elif expr.is_cond():
        # Evaluated in order: cond, src1, src2
        cond, src1, src2 = [
            self.apply_expr_on_state_visit_cache(sub, state, cache, level + 1)
            for sub in (expr.cond, expr.src1, expr.src2)
        ]
        if any(sub.is_id(TOPSTR) for sub in (cond, src1, src2)):
            result = exprid_top(expr)
        else:
            result = ExprCond(cond, src1, src2)
    elif expr.is_slice():
        sliced = self.apply_expr_on_state_visit_cache(
            expr.arg, state, cache, level + 1)
        if sliced.is_id(TOPSTR):
            result = exprid_top(expr)
        else:
            result = ExprSlice(sliced, expr.start, expr.stop)
    elif expr.is_op():
        operands = []
        for sub in expr.args:
            value = self.apply_expr_on_state_visit_cache(
                sub, state, cache, level + 1)
            assert sub.size == value.size
            # Stop at the first TOP operand; remaining operands are not
            # evaluated and the TOP result is not cached
            if value.is_id(TOPSTR):
                return exprid_top(expr)
            operands.append(value)
        result = ExprOp(expr.op, *operands)
    elif expr.is_compose():
        parts = []
        for sub in expr.args:
            value = self.apply_expr_on_state_visit_cache(
                sub, state, cache, level + 1)
            if value.is_id(TOPSTR):
                return exprid_top(expr)
            parts.append(value)
        result = ExprCompose(*parts)
    else:
        raise TypeError("Unknown expr type")

    result = self.expr_simp(result)
    assert expr.size == result.size
    cache[expr] = result
    return result
def from_ExprOp(self, expr):
    """Translate an ExprOp into a C expression string.

    Unary, binary and n-ary (associative) operators are handled separately.
    Operations wider than self.NATIVE_INT_MAX_SIZE bits go through the
    bignum_* C helpers instead of native C integer arithmetic.

    @expr: ExprOp instance to translate
    Returns the C source text (str) computing @expr.
    Raises NotImplementedError for unsupported operators, RuntimeError for
    unsupported float conversion sizes and ValueError for FPU operations on
    operands wider than self.NATIVE_INT_MAX_SIZE.
    """
    if len(expr.args) == 1:
        if expr.op == 'parity':
            arg = expr.args[0]
            out = self.from_expr(arg)
            if arg.size <= self.NATIVE_INT_MAX_SIZE:
                out = "(%s&%s)" % (out, self._size2mask(arg.size))
            else:
                # Mask the bignum down to 8 bits before converting it to a
                # native integer for parity()
                out = 'bignum_mask(%s, 8)' % out
                out = 'bignum_to_uint64(%s)' % out
            out = 'parity(%s)' % out
            return out

        elif expr.op.startswith("zeroExt_"):
            arg = expr.args[0]
            if expr.size == arg.size:
                # Nothing to extend; still return C text, not an Expr
                return self.from_expr(arg)
            return self.from_expr(
                ExprCompose(arg, ExprInt(0, expr.size - arg.size)))

        elif expr.op.startswith("signExt_"):
            arg = expr.args[0]
            if expr.size == arg.size:
                # Nothing to extend; still return C text, not an Expr
                return self.from_expr(arg)
            add_size = expr.size - arg.size
            # Upper bits are all ones when the sign bit is set, zeros otherwise
            new_expr = ExprCompose(
                arg,
                ExprCond(arg.msb(),
                         ExprInt(size2mask(add_size), add_size),
                         ExprInt(0, add_size)))
            return self.from_expr(new_expr)

        elif expr.op in ['cntleadzeros', 'cnttrailzeros']:
            arg = expr.args[0]
            out = self.from_expr(arg)
            if arg.size <= self.NATIVE_INT_MAX_SIZE:
                out = "%s(0x%x, %s)" % (expr.op, expr.args[0].size, out)
            else:
                out = "bignum_%s(%s, %d)" % (expr.op, out, arg.size)
            return out

        elif expr.op == '!':
            arg = expr.args[0]
            out = self.from_expr(arg)
            if expr.size <= self.NATIVE_INT_MAX_SIZE:
                out = "(~ %s)&%s" % (out, self._size2mask(arg.size))
            else:
                out = "bignum_not(%s)" % out
                # The mask width must be the numeric size, not the text
                # "expr.size"
                out = "bignum_mask(%s, %d)" % (out, expr.size)
            return out

        elif expr.op in ["ftan", "frndint", "f2xm1", "fsin", "fsqrt", "fabs",
                         "fcos", "fchs"]:
            return "fpu_%s%d(%s)" % (
                expr.op,
                expr.size,
                self.from_expr(expr.args[0]),
            )

        elif (expr.op.startswith("access_") or
              expr.op.startswith("load_") or
              expr.op.startswith("fxam_c")):
            arg = expr.args[0]
            out = self.from_expr(arg)
            out = "%s(%s)" % (expr.op, out)
            return out

        elif expr.op == "-":
            arg = expr.args[0]
            out = self.from_expr(arg)
            if arg.size <= self.NATIVE_INT_MAX_SIZE:
                out = "(%s(%s))" % (expr.op, out)
                out = "(%s&%s)" % (out, self._size2mask(arg.size))
            else:
                out = "bignum_sub(bignum_from_uint64(0), %s)" % out
                out = "bignum_mask(%s, %d)" % (out, expr.size)
            return out

        elif expr.op.startswith("fpround_"):
            return "%s_fp%d(%s)" % (
                expr.op,
                expr.size,
                self.from_expr(expr.args[0]),
            )

        elif expr.op == "sint_to_fp":
            size = expr.size
            arg = expr.args[0]
            if size not in [32, 64]:
                raise RuntimeError(
                    "Unsupported size for sint_to_fp: %r" % size)
            return "%s_%d(%s)" % (expr.op, size, self.from_expr(arg))

        elif expr.op.startswith("fp_to_sint"):
            dest_size = expr.size
            arg_size = expr.args[0].size
            if (arg_size, dest_size) in [(32, 32), (64, 64), (64, 32)]:
                func = "fp%d_to_sint%d" % (arg_size, dest_size)
            else:
                raise RuntimeError(
                    "Unsupported size for fp_to_sint: %r to %r" % (
                        arg_size, dest_size))
            return "%s(%s)" % (func, self.from_expr(expr.args[0]))

        elif expr.op.startswith("fpconvert_fp"):
            dest_size = expr.size
            arg_size = expr.args[0].size
            if (arg_size, dest_size) in [(32, 64), (64, 32)]:
                func = "fp%d_to_fp%d" % (arg_size, dest_size)
            else:
                raise RuntimeError(
                    "Unsupported size for fpconvert: %r to %r" % (
                        arg_size, dest_size))
            return "%s(%s)" % (func, self.from_expr(expr.args[0]))

        else:
            raise NotImplementedError('Unknown op: %r' % expr.op)

    elif len(expr.args) == 2:
        if expr.op == TOK_EQUAL:
            return '(((%s&%s) == (%s&%s))?1:0)' % (
                self.from_expr(expr.args[0]),
                self._size2mask(expr.args[0].size),
                self.from_expr(expr.args[1]),
                self._size2mask(expr.args[1].size),
            )

        elif expr.op in self.dct_shift:
            arg0 = self.from_expr(expr.args[0])
            arg1 = self.from_expr(expr.args[1])
            if expr.size <= self.NATIVE_INT_MAX_SIZE:
                out = 'SHIFT_%s(%d, %s, %s)' % (
                    self.dct_shift[expr.op].upper(),
                    expr.args[0].size,
                    arg0,
                    arg1
                )
            else:
                op = {
                    "<<": "lshift",
                    ">>": "rshift",
                    "a>>": "a_rshift"
                }
                out = "bignum_%s(%s, bignum_to_uint64(%s))" % (
                    op[expr.op], arg0, arg1
                )
                out = "bignum_mask(%s, %d)" % (out, expr.size)
            return out

        elif expr.is_associative():
            args = [self.from_expr(arg) for arg in expr.args]
            if expr.size <= self.NATIVE_INT_MAX_SIZE:
                out = (" %s " % expr.op).join(args)
                out = "((%s)&%s)" % (out, self._size2mask(expr.size))
            else:
                op_to_bn_func = {
                    "+": "add",
                    "*": "mul",
                    "|": "or",
                    "^": "xor",
                    "&": "and",
                }
                args = list(expr.args)
                out = self.from_expr(args.pop())
                while args:
                    out = 'bignum_mask(bignum_%s(%s, %s), %d)' % (
                        op_to_bn_func[expr.op],
                        out,
                        self.from_expr(args.pop()),
                        expr.size
                    )
            return out

        elif expr.op in ['-']:
            return '(((%s&%s) %s (%s&%s))&%s)' % (
                self.from_expr(expr.args[0]),
                self._size2mask(expr.args[0].size),
                str(expr.op),
                self.from_expr(expr.args[1]),
                self._size2mask(expr.args[1].size),
                self._size2mask(expr.args[0].size)
            )

        elif expr.op in self.dct_rot:
            arg0 = self.from_expr(expr.args[0])
            arg1 = self.from_expr(expr.args[1])
            if expr.size <= self.NATIVE_INT_MAX_SIZE:
                out = '(%s(%s, %s, %s) &%s)' % (
                    self.dct_rot[expr.op],
                    expr.args[0].size,
                    arg0,
                    arg1,
                    self._size2mask(expr.args[0].size),
                )
            else:
                op = {
                    ">>>": "ror",
                    "<<<": "rol"
                }
                out = "bignum_%s(%s, %d, bignum_to_uint64(%s))" % (
                    op[expr.op], arg0, expr.size, arg1
                )
                out = "bignum_mask(%s, %d)" % (out, expr.size)
            return out

        elif expr.op == 'x86_cpuid':
            return "%s(%s, %s)" % (expr.op,
                                   self.from_expr(expr.args[0]),
                                   self.from_expr(expr.args[1]))

        elif expr.op.startswith("fcom"):
            arg0 = self.from_expr(expr.args[0])
            arg1 = self.from_expr(expr.args[1])
            if not expr.args[0].size <= self.NATIVE_INT_MAX_SIZE:
                raise ValueError(
                    "Bad semantic: fpu do operations do not support such size"
                )
            out = "fpu_%s(%s, %s)" % (expr.op, arg0, arg1)
            return out

        elif expr.op in ["fadd", "fsub", "fdiv", 'fmul', "fscale",
                         "fprem", "fyl2x", "fpatan"]:
            arg0 = self.from_expr(expr.args[0])
            arg1 = self.from_expr(expr.args[1])
            if not expr.args[0].size <= self.NATIVE_INT_MAX_SIZE:
                raise ValueError(
                    "Bad semantic: fpu do operations do not support such size"
                )
            out = "fpu_%s%d(%s, %s)" % (expr.op, expr.size, arg0, arg1)
            return out

        elif expr.op == "segm":
            return "segm2addr(jitcpu, %s, %s)" % (
                self.from_expr(expr.args[0]),
                self.from_expr(expr.args[1])
            )

        elif expr.op in ['udiv', 'umod']:
            arg0 = self.from_expr(expr.args[0])
            arg1 = self.from_expr(expr.args[1])
            if expr.size <= self.NATIVE_INT_MAX_SIZE:
                out = '%s%d(%s, %s)' % (
                    expr.op,
                    expr.args[0].size,
                    arg0,
                    arg1
                )
            else:
                out = "bignum_%s(%s, %s)" % (
                    expr.op,
                    arg0,
                    arg1
                )
                out = "bignum_mask(%s, %d)" % (out, expr.size)
            return out

        elif expr.op in ['sdiv', 'smod']:
            arg0 = self.from_expr(expr.args[0])
            arg1 = self.from_expr(expr.args[1])
            if expr.size <= self.NATIVE_INT_MAX_SIZE:
                out = '%s%d(%s, %s)' % (
                    expr.op,
                    expr.args[0].size,
                    arg0,
                    arg1
                )
            else:
                out = "bignum_%s(%s, %s, %d)" % (
                    expr.op,
                    arg0,
                    arg1,
                    expr.size
                )
                out = "bignum_mask(%s, %d)" % (out, expr.size)
            return out

        elif expr.op in ["bcdadd", "bcdadd_cf"]:
            return "%s_%d(%s, %s)" % (
                expr.op, expr.args[0].size,
                self.from_expr(expr.args[0]),
                self.from_expr(expr.args[1])
            )

        else:
            raise NotImplementedError('Unknown op: %r' % expr.op)

    elif len(expr.args) >= 3 and expr.is_associative():  # ?????
        oper = ['(%s&%s)' % (
            self.from_expr(arg),
            self._size2mask(arg.size),
        )
            for arg in expr.args]
        oper = str(expr.op).join(oper)
        return "((%s)&%s)" % (
            oper,
            self._size2mask(expr.args[0].size)
        )
    else:
        raise NotImplementedError('Unknown op: %s' % expr.op)
import pickle from miasm2.expression.expression import ExprInt, ExprAff, ExprId, \ Expr, ExprCompose, ExprMem a = ExprId("test", 8) b = ExprInt(1338, 8) c = a + b d = ExprCompose(a, b) e = ExprMem(a, 32) f = a[:8] aff = ExprAff(a, b) print 'Pickling' out = pickle.dumps((a, b, c, d, e, f, aff)) print 'Unpickling' new_a, new_b, new_c, new_d, new_e, new_f, new_aff = pickle.loads(out) print 'Result' print a, b, c, aff print id(a), id(b), id(c), id(d), id(e), id(f), id(aff) print new_a, new_b, new_c, new_d, new_e, new_f, new_aff print id(new_a), id(new_b), id(new_c), id(new_d), id(new_e), id(new_f), id(new_aff) assert a == new_a assert b == new_b assert c == new_c assert d == new_d assert e == new_e assert f == new_f assert aff == new_aff