def post_add_block(self, block, ir_blocks):
    """
    Post-process the IR blocks of @block: for every IR block which assigns
    both PC and RA, insert an extra IR block holding the call effects.

    @block: native block the IR blocks were built from
    @ir_blocks: iterable of IR blocks generated for @block
    Return the list of resulting IR blocks.
    Raise ValueError if the value assigned to RA is not the offset of the
    penultimate instruction of @block plus 8.
    """
    IntermediateRepresentation.post_add_block(self, block, ir_blocks)

    result = []
    for irblock in ir_blocks:
        # Keep the last value assigned to PC and to RA in this IR block
        target = None
        ret_addr = None
        for assignblk in irblock:
            target = assignblk.get(self.arch.regs.PC, target)
            ret_addr = assignblk.get(self.arch.regs.RA, ret_addr)

        if target is None or ret_addr is None:
            # PC and RA are not both assigned: keep the block untouched
            result.append(irblock)
            continue

        if ret_addr.is_loc():
            # Turn a location into an integer when its offset is known
            known_offset = self.symbol_pool.loc_key_to_offset(ret_addr.loc_key)
            if known_offset is not None:
                ret_addr = ExprInt(known_offset, 32)

        if not ret_addr.is_int():
            # NOTE(review): the IR block is not appended to the returned
            # list in this case -- confirm this is intended
            continue

        instr = block.lines[-2]
        if int(ret_addr) != instr.offset + 8:
            raise ValueError("Wrong arg")

        # CALL
        next_label = block.get_next()
        call_label = self.gen_label()
        call_assignblks = self.call_effects(target, instr)
        jump_to_next = ExprAff(self.IRDst,
                               ExprId(next_label, size=self.pc.size))
        call_assignblks.append(AssignBlock([jump_to_next], instr))

        # New block with the call effects, then the original block
        # redirected to it
        result.append(IRBlock(call_label, call_assignblks))
        result.append(irblock.set_dst(ExprId(call_label, size=self.pc.size)))
    return result
def offset_to_ptr(base, offset):
    """
    Build the canonized expression of the address @base + @offset
    @base: symbolic base address
    @offset: integer offset relative to @base
    """
    if base.is_id(INTERNAL_INTBASE_NAME):
        # Internal integer base: the pointer is the offset itself
        return ExprInt(offset, base.size).canonize()
    if offset == 0:
        # No displacement: the pointer is the base itself
        return base.canonize()
    return (base + ExprInt(offset, base.size)).canonize()
def simp_add_multiple(_, expr):
    """
    Gather identical terms of an addition:
    X + X => 2 * X
    X + X * int1 => X * (1 + int1)
    X * int1 + (- X) => X * (int1 - 1)
    X + (X << int1) => X * (1 + 2 ** int1)
    Correct even if addition overflow/underflow

    @expr: expression to simplify; returned as is when it is not a '+'
    operation or when no term could be merged
    """
    if not expr.is_op('+'):
        return expr

    # Map each distinct term to its signed multiplicity
    operands = {}
    for arg in expr.args:
        if arg.is_op('*') and arg.args[1].is_int():
            term, factor = arg.args
            weight = int(factor)
        elif arg.is_op('<<') and arg.args[1].is_int():
            term, factor = arg.args
            weight = 2 ** int(factor)
        elif arg.is_op("-"):
            negated = arg.args[0]
            if negated.is_op('<<') and negated.args[1].is_int():
                term, factor = negated.args
                weight = -(2 ** int(factor))
            else:
                term, weight = negated, -1
        else:
            term, weight = arg, 1
        operands[term] = operands.get(term, 0) + weight

    # Best effort to factor common args:
    # (a + b) * 3 + a + b => (a + b) * 4
    # Does not factor:
    # (a + b) * 3 + 2 * a + b => (a + b) * 4 + a
    while True:
        for candidate in operands:
            if not candidate.is_op('+'):
                continue
            parts = candidate.args
            if not all(part in operands for part in parts):
                continue
            multiplicities = set(operands[part] for part in parts)
            if len(multiplicities) != 1:
                continue
            shared = multiplicities.pop()
            # Fold the components into the sum which contains them; the
            # dictionary has changed, so restart the scan right away
            for part in parts:
                del operands[part]
            operands[candidate] += shared
            break
        else:
            # Full pass without any merge: fixpoint reached
            break

    out = [
        term if count == 1 else term * ExprInt(count, expr.size)
        for term, count in operands.items()
        if count != 0
    ]

    if len(out) == len(expr.args):
        # No reductions
        return expr
    if not out:
        return ExprInt(0, expr.size)
    if len(out) == 1:
        return out[0]
    return ExprOp('+', *out)
def test_maddu(self):
    """Test MADDU execution"""
    r0, r1 = ExprId("R0", 32), ExprId("R1", 32)
    hi, lo = ExprId("HI", 32), ExprId("LO", 32)

    # MADDU Rn,Rm
    # HI and LO start at zero
    init_state = [(r0, ExprInt(0x2, 32)),
                  (r1, ExprInt(0xFFFFFFFF, 32)),
                  (hi, ExprInt(0, 32)),
                  (lo, ExprInt(0, 32))]
    expected = [(hi, ExprInt(0x1, 32)),
                (lo, ExprInt(0xFFFFFFFE, 32))]
    exec_instruction("MADDU R0, R1", init_state, expected)

    # HI and LO start at one
    # NOTE(review): the expected HI equals the one of the previous case
    # although HI starts at 1 here -- confirm this is the intended semantic
    init_state = [(r0, ExprInt(0x2, 32)),
                  (r1, ExprInt(0xFFFFFFFF, 32)),
                  (hi, ExprInt(1, 32)),
                  (lo, ExprInt(1, 32))]
    expected = [(hi, ExprInt(0x1, 32)),
                (lo, ExprInt(0xFFFFFFFF, 32))]
    exec_instruction("MADDU R0, R1", init_state, expected)
def from_ExprOp(self, expr):
    """
    Translate the operation @expr into a C expression string.

    Operations whose size is at most self.NATIVE_INT_MAX_SIZE are emitted
    as native C integer expressions; wider ones are emitted as calls to
    the bignum_* helpers.

    @expr: operation to translate (has .op, .args and .size)
    Return the C source code as a string.
    Raise NotImplementedError for an unknown operator, RuntimeError for an
    unsupported floating point conversion size and ValueError for a
    floating point operation on operands wider than the native size.
    """
    if len(expr.args) == 1:
        if expr.op == 'parity':
            arg = expr.args[0]
            out = self.from_expr(arg)
            if arg.size <= self.NATIVE_INT_MAX_SIZE:
                out = "(%s&%s)" % (out, self._size2mask(arg.size))
            else:
                # Fixed: the format string has a single placeholder, giving
                # it the tuple (out, 8) raised a TypeError
                out = 'bignum_mask(%s, 8)' % out
                out = 'bignum_to_uint64(%s)' % out
            out = 'parity(%s)' % out
            return out

        elif expr.op.startswith("zeroExt_"):
            arg = expr.args[0]
            if expr.size == arg.size:
                # Fixed: return the translation of the argument, not the
                # expression object itself
                return self.from_expr(arg)
            return self.from_expr(
                ExprCompose(arg, ExprInt(0, expr.size - arg.size))
            )

        elif expr.op.startswith("signExt_"):
            arg = expr.args[0]
            if expr.size == arg.size:
                # Fixed: return the translation of the argument, not the
                # expression object itself
                return self.from_expr(arg)
            add_size = expr.size - arg.size
            # Extend with ones or zeros depending on the most significant
            # bit of the argument
            new_expr = ExprCompose(
                arg,
                ExprCond(
                    arg.msb(),
                    ExprInt(size2mask(add_size), add_size),
                    ExprInt(0, add_size)
                )
            )
            return self.from_expr(new_expr)

        elif expr.op in ['cntleadzeros', 'cnttrailzeros']:
            arg = expr.args[0]
            out = self.from_expr(arg)
            if arg.size <= self.NATIVE_INT_MAX_SIZE:
                out = "%s(0x%x, %s)" % (expr.op, expr.args[0].size, out)
            else:
                out = "bignum_%s(%s, %d)" % (expr.op, out, arg.size)
            return out

        elif expr.op == '!':
            arg = expr.args[0]
            out = self.from_expr(arg)
            if expr.size <= self.NATIVE_INT_MAX_SIZE:
                out = "(~ %s)&%s" % (out, self._size2mask(arg.size))
            else:
                out = "bignum_not(%s)" % out
                # Fixed: emit the size value; the literal text "expr.size"
                # was previously written into the generated C code
                out = "bignum_mask(%s, %d)" % (out, expr.size)
            return out

        elif expr.op in [
                "ftan", "frndint", "f2xm1", "fsin", "fsqrt", "fabs", "fcos",
                "fchs",
        ]:
            return "fpu_%s%d(%s)" % (
                expr.op,
                expr.size,
                self.from_expr(expr.args[0]),
            )

        elif (expr.op.startswith("access_") or
              expr.op.startswith("load_") or
              expr.op.startswith("fxam_c")):
            arg = expr.args[0]
            out = self.from_expr(arg)
            out = "%s(%s)" % (expr.op, out)
            return out

        elif expr.op == "-":
            arg = expr.args[0]
            out = self.from_expr(arg)
            if arg.size <= self.NATIVE_INT_MAX_SIZE:
                out = "(%s(%s))" % (expr.op, out)
                out = "(%s&%s)" % (out, self._size2mask(arg.size))
            else:
                out = "bignum_sub(bignum_from_uint64(0), %s)" % out
                out = "bignum_mask(%s, %d)" % (out, expr.size)
            return out

        elif expr.op.startswith("fpround_"):
            return "%s_fp%d(%s)" % (
                expr.op,
                expr.size,
                self.from_expr(expr.args[0]),
            )

        elif expr.op == "sint_to_fp":
            size = expr.size
            arg = expr.args[0]
            if size not in [32, 64]:
                raise RuntimeError(
                    "Unsupported size for sint_to_fp: %r" % size
                )
            return "%s_%d(%s)" % (expr.op, size, self.from_expr(arg))

        elif expr.op.startswith("fp_to_sint"):
            dest_size = expr.size
            arg_size = expr.args[0].size
            if (arg_size, dest_size) in [
                    (32, 32), (64, 64), (64, 32),
            ]:
                func = "fp%d_to_sint%d" % (arg_size, dest_size)
            else:
                raise RuntimeError(
                    "Unsupported size for fp_to_sint: %r to %r" % (
                        arg_size,
                        dest_size
                    ))
            return "%s(%s)" % (func, self.from_expr(expr.args[0]))

        elif expr.op.startswith("fpconvert_fp"):
            dest_size = expr.size
            arg_size = expr.args[0].size
            if (arg_size, dest_size) in [
                    (32, 64), (64, 32)
            ]:
                func = "fp%d_to_fp%d" % (arg_size, dest_size)
            else:
                raise RuntimeError(
                    "Unsupported size for fpconvert: %r to %r" % (arg_size,
                                                                  dest_size)
                )
            return "%s(%s)" % (func, self.from_expr(expr.args[0]))

        else:
            raise NotImplementedError('Unknown op: %r' % expr.op)

    elif len(expr.args) == 2:
        if expr.op == TOK_EQUAL:
            return '(((%s&%s) == (%s&%s))?1:0)' % (
                self.from_expr(expr.args[0]),
                self._size2mask(expr.args[0].size),
                self.from_expr(expr.args[1]),
                self._size2mask(expr.args[1].size),
            )

        elif expr.op in self.dct_shift:
            arg0 = self.from_expr(expr.args[0])
            arg1 = self.from_expr(expr.args[1])
            if expr.size <= self.NATIVE_INT_MAX_SIZE:
                out = 'SHIFT_%s(%d, %s, %s)' % (
                    self.dct_shift[expr.op].upper(),
                    expr.args[0].size,
                    arg0,
                    arg1
                )
            else:
                op = {
                    "<<": "lshift",
                    ">>": "rshift",
                    "a>>": "a_rshift"
                }
                out = "bignum_%s(%s, bignum_to_uint64(%s))" % (
                    op[expr.op], arg0, arg1
                )
                out = "bignum_mask(%s, %d)" % (out, expr.size)
            return out

        elif expr.is_associative():
            args = [self.from_expr(arg) for arg in expr.args]
            if expr.size <= self.NATIVE_INT_MAX_SIZE:
                out = (" %s " % expr.op).join(args)
                out = "((%s)&%s)" % (out, self._size2mask(expr.size))
            else:
                op_to_bn_func = {
                    "+": "add",
                    "*": "mul",
                    "|": "or",
                    "^": "xor",
                    "&": "and",
                }
                # Fold the operands, last one first, masking each
                # intermediate result to the operation size
                args = list(expr.args)
                out = self.from_expr(args.pop())
                while args:
                    out = 'bignum_mask(bignum_%s(%s, %s), %d)' % (
                        op_to_bn_func[expr.op],
                        out,
                        self.from_expr(args.pop()),
                        expr.size
                    )
            return out

        elif expr.op in ['-']:
            return '(((%s&%s) %s (%s&%s))&%s)' % (
                self.from_expr(expr.args[0]),
                self._size2mask(expr.args[0].size),
                str(expr.op),
                self.from_expr(expr.args[1]),
                self._size2mask(expr.args[1].size),
                self._size2mask(expr.args[0].size)
            )

        elif expr.op in self.dct_rot:
            arg0 = self.from_expr(expr.args[0])
            arg1 = self.from_expr(expr.args[1])
            if expr.size <= self.NATIVE_INT_MAX_SIZE:
                out = '(%s(%s, %s, %s) &%s)' % (
                    self.dct_rot[expr.op],
                    expr.args[0].size,
                    arg0,
                    arg1,
                    self._size2mask(expr.args[0].size),
                )
            else:
                op = {
                    ">>>": "ror",
                    "<<<": "rol"
                }
                # Fixed: the helper name is a string ("ror" / "rol");
                # formatting it with %d raised a TypeError
                out = "bignum_%s(%s, %d, bignum_to_uint64(%s))" % (
                    op[expr.op], arg0, expr.size, arg1
                )
                out = "bignum_mask(%s, %d)" % (out, expr.size)
            return out

        elif expr.op == 'x86_cpuid':
            return "%s(%s, %s)" % (expr.op,
                                   self.from_expr(expr.args[0]),
                                   self.from_expr(expr.args[1]))

        elif expr.op.startswith("fcom"):
            arg0 = self.from_expr(expr.args[0])
            arg1 = self.from_expr(expr.args[1])
            if not expr.args[0].size <= self.NATIVE_INT_MAX_SIZE:
                raise ValueError("Bad semantic: fpu do operations do not support such size")
            out = "fpu_%s(%s, %s)" % (expr.op, arg0, arg1)
            return out

        elif expr.op in ["fadd", "fsub", "fdiv", 'fmul', "fscale",
                         "fprem", "fyl2x", "fpatan"]:
            arg0 = self.from_expr(expr.args[0])
            arg1 = self.from_expr(expr.args[1])
            if not expr.args[0].size <= self.NATIVE_INT_MAX_SIZE:
                raise ValueError("Bad semantic: fpu do operations do not support such size")
            out = "fpu_%s%d(%s, %s)" % (expr.op, expr.size, arg0, arg1)
            return out

        elif expr.op == "segm":
            return "segm2addr(jitcpu, %s, %s)" % (
                self.from_expr(expr.args[0]),
                self.from_expr(expr.args[1])
            )

        elif expr.op in ['udiv', 'umod']:
            arg0 = self.from_expr(expr.args[0])
            arg1 = self.from_expr(expr.args[1])
            if expr.size <= self.NATIVE_INT_MAX_SIZE:
                out = '%s%d(%s, %s)' % (
                    expr.op,
                    expr.args[0].size,
                    arg0,
                    arg1
                )
            else:
                out = "bignum_%s(%s, %s)" % (
                    expr.op,
                    arg0,
                    arg1
                )
                out = "bignum_mask(%s, %d)" % (out, expr.size)
            return out

        elif expr.op in ['idiv', 'imod']:
            arg0 = self.from_expr(expr.args[0])
            arg1 = self.from_expr(expr.args[1])
            if expr.size <= self.NATIVE_INT_MAX_SIZE:
                out = '%s%d(%s, %s)' % (
                    expr.op,
                    expr.args[0].size,
                    arg0,
                    arg1
                )
            else:
                out = "bignum_%s(%s, %s, %d)" % (
                    expr.op,
                    arg0,
                    arg1,
                    expr.size
                )
                out = "bignum_mask(%s, %d)" % (out, expr.size)
            return out

        elif expr.op in ["bcdadd", "bcdadd_cf"]:
            return "%s_%d(%s, %s)" % (
                expr.op, expr.args[0].size,
                self.from_expr(expr.args[0]),
                self.from_expr(expr.args[1])
            )

        else:
            raise NotImplementedError('Unknown op: %r' % expr.op)

    elif len(expr.args) >= 3 and expr.is_associative():
        # Associative operation with three operands or more: mask every
        # operand, chain them with the operator, then mask the result.
        # NOTE(review): this branch always emits native C, whatever
        # expr.size is -- confirm wide operands cannot reach it
        oper = ['(%s&%s)' % (
            self.from_expr(arg),
            self._size2mask(arg.size),
        ) for arg in expr.args]
        oper = str(expr.op).join(oper)
        return "((%s)&%s)" % (
            oper,
            self._size2mask(expr.args[0].size)
        )

    else:
        raise NotImplementedError('Unknown op: %s' % expr.op)
print k, v reqs.append((k, v)) all_info.append((addr, reqs)) all_cases = set() symbexec = SymbolicExecutionEngine(ir_arch) for addr, reqs_cond in all_info: out = ['(set-logic QF_ABV)', '(set-info :smt-lib-version 2.0)'] conditions = [] all_ids = set() for expr, value in reqs_cond: all_ids.update(get_expr_ids(expr)) expr_test = ExprCond(expr, ExprInt(1, value.size), ExprInt(0, value.size)) cond = translator_smt2.from_expr(ExprAssign(expr_test, value)) conditions.append(cond) for name in all_ids: out.append("(declare-fun %s () (_ BitVec %d))" % (name, name.size)) if not out: continue out += conditions out.append('(check-sat)') open('out.dot', 'w').write('\n'.join(out)) try: cases = subprocess.check_output([ "/home/serpilliere/tools/stp/stp", "-p", '--SMTLIB2', "out.dot"
def simp_cc_conds(expr_simp, expr):
    """
    Rewrite a condition code operation applied to flag operations into a
    direct comparison of the flag operands.
    Example: CC_U<(FLAG_SUB_CF(A, B)) => A <u B

    @expr_simp: not used by this simplification
    @expr: expression to simplify; returned as is if no pattern matches
    """

    def flag_operands():
        # Operands of the first flag operation of the condition
        return expr.args[0].args

    def negated(cond):
        # One bit ExprCond built as (@cond, 0, 1)
        return ExprCond(cond, ExprInt(0, 1), ExprInt(1, 1))

    def sign_zero_eq_form(cc_name):
        # Three operands form: FLAG_SIGN_SUB, the integer 0, FLAG_EQ_CMP,
        # both flags sharing the same operands
        return (expr.is_op(cc_name) and
                len(expr.args) == 3 and
                expr.args[0].is_op("FLAG_SIGN_SUB") and
                expr.args[2].is_op("FLAG_EQ_CMP") and
                expr.args[0].args == expr.args[2].args and
                expr.args[1].is_int(0))

    if expr.is_op("CC_U>=") and test_cc_eq_args(expr, "FLAG_SUB_CF"):
        return negated(ExprOp(TOK_INF_UNSIGNED, *flag_operands()))

    if expr.is_op("CC_U<") and test_cc_eq_args(expr, "FLAG_SUB_CF"):
        return ExprOp(TOK_INF_UNSIGNED, *flag_operands())

    if expr.is_op("CC_NEG") and test_cc_eq_args(expr, "FLAG_SIGN_SUB"):
        return ExprOp(TOK_INF_SIGNED, *flag_operands())

    if expr.is_op("CC_POS") and test_cc_eq_args(expr, "FLAG_SIGN_SUB"):
        return negated(ExprOp(TOK_INF_SIGNED, *flag_operands()))

    if expr.is_op("CC_EQ") and test_cc_eq_args(expr, "FLAG_EQ"):
        arg = flag_operands()[0]
        return ExprOp(TOK_EQUAL, arg, ExprInt(0, arg.size))

    if expr.is_op("CC_NE") and test_cc_eq_args(expr, "FLAG_EQ"):
        arg = flag_operands()[0]
        return negated(ExprOp(TOK_EQUAL, arg, ExprInt(0, arg.size)))

    if expr.is_op("CC_NE") and test_cc_eq_args(expr, "FLAG_EQ_CMP"):
        return negated(ExprOp(TOK_EQUAL, *flag_operands()))

    if expr.is_op("CC_EQ") and test_cc_eq_args(expr, "FLAG_EQ_CMP"):
        return ExprOp(TOK_EQUAL, *flag_operands())

    if expr.is_op("CC_NE") and test_cc_eq_args(expr, "FLAG_EQ_AND"):
        # NOTE(review): the result is the '&' of the operands, not a one
        # bit condition like the other rules -- confirm sizes agree
        return ExprOp("&", *flag_operands())

    if expr.is_op("CC_EQ") and test_cc_eq_args(expr, "FLAG_EQ_AND"):
        return negated(ExprOp("&", *flag_operands()))

    if (expr.is_op("CC_S>") and
            test_cc_eq_args(expr,
                            "FLAG_SIGN_SUB", "FLAG_SUB_OF", "FLAG_EQ_CMP")):
        return negated(ExprOp(TOK_INF_EQUAL_SIGNED, *flag_operands()))

    if sign_zero_eq_form("CC_S>"):
        return negated(ExprOp(TOK_INF_EQUAL_SIGNED, *flag_operands()))

    if (expr.is_op("CC_S>=") and
            test_cc_eq_args(expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF")):
        return negated(ExprOp(TOK_INF_SIGNED, *flag_operands()))

    if (expr.is_op("CC_S<") and
            test_cc_eq_args(expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF")):
        return ExprOp(TOK_INF_SIGNED, *flag_operands())

    if (expr.is_op("CC_S<=") and
            test_cc_eq_args(expr,
                            "FLAG_SIGN_SUB", "FLAG_SUB_OF", "FLAG_EQ_CMP")):
        return ExprOp(TOK_INF_EQUAL_SIGNED, *flag_operands())

    if sign_zero_eq_form("CC_S<="):
        return ExprOp(TOK_INF_EQUAL_SIGNED, *flag_operands())

    if (expr.is_op("CC_U<=") and
            test_cc_eq_args(expr, "FLAG_SUB_CF", "FLAG_EQ_CMP")):
        return ExprOp(TOK_INF_EQUAL_UNSIGNED, *flag_operands())

    if (expr.is_op("CC_U>") and
            test_cc_eq_args(expr, "FLAG_SUB_CF", "FLAG_EQ_CMP")):
        return negated(ExprOp(TOK_INF_EQUAL_UNSIGNED, *flag_operands()))

    if (expr.is_op("CC_S<") and
            test_cc_eq_args(expr, "FLAG_SIGN_ADD", "FLAG_ADD_OF")):
        # Flags of an addition: compare the first operand with the
        # opposite of the second one
        arg0, arg1 = flag_operands()
        return ExprOp(TOK_INF_SIGNED, arg0, -arg1)

    return expr
""" Module to generate C code for a given native @block """ from miasm2.expression.expression import Expr, ExprId, ExprLoc, ExprInt, \ ExprMem, ExprCond, LocKey from miasm2.ir.ir import IRBlock, AssignBlock from miasm2.ir.translators.C import TranslatorC from miasm2.core.asmblock import AsmBlockBad from miasm2.expression.simplifications import expr_simp_high_to_explicit TRANSLATOR_NO_SYMBOL = TranslatorC(loc_db=None) SIZE_TO_MASK = { size: TRANSLATOR_NO_SYMBOL.from_expr(ExprInt(0, size).mask) for size in (1, 2, 3, 7, 8, 16, 32, 64) } class Attributes(object): """ Store an irblock attributes """ def __init__(self, log_mn=False, log_regs=False): self.mem_read = False self.mem_write = False self.set_exception = False self.log_mn = log_mn self.log_regs = log_regs self.instr = None
def custom_func_read(mem):
    """
    Memory read hook: the 32 bit read at address 0x1000 yields id_x, any
    other memory expression @mem is returned unchanged
    """
    hooked = ExprMem(ExprInt(0x1000, 32), 32)
    return id_x if mem == hooked else mem
def update_flag_nf(arg):
    """Return the assignment of nf: FLAG_SIGN_SUB of @arg and a zero of
    the same size"""
    zero = ExprInt(0, arg.size)
    sign = ExprOp("FLAG_SIGN_SUB", arg, zero)
    return [ExprAssign(nf, sign)]
def update_flag_subwc_cf(op1, op2, op3):
    """Compute cf as the inverse of FLAG_SUBWC_CF(@op1, @op2, @op3)
    (presumably a subtract with carry, not an addition -- TODO confirm)"""
    borrow = ExprOp("FLAG_SUBWC_CF", op1, op2, op3)
    inverted = borrow ^ ExprInt(1, 1)
    return [ExprAssign(cf, inverted)]
def update_flag_subwc_of(op1, op2, op3):
    """Compute of as FLAG_SUBWC_OF(@op1, @op2, @op3)
    (presumably a subtract with carry, not an addition -- TODO confirm)"""
    overflow = ExprOp("FLAG_SUBWC_OF", op1, op2, op3)
    return [ExprAssign(of, overflow)]


def update_flag_arith_subwc_co(arg1, arg2, arg3):
    """Return the cf assignments followed by the of assignments for the
    operands @arg1, @arg2, @arg3"""
    assignments = list(update_flag_subwc_cf(arg1, arg2, arg3))
    assignments.extend(update_flag_subwc_of(arg1, arg2, arg3))
    return assignments


# Condition mnemonic -> expression of the condition, built from the flags
cond2expr = {
    'EQ': ExprOp("CC_EQ", zf),
    'NE': ExprOp("CC_NE", zf),
    'CS': ExprOp("CC_U>=", cf ^ ExprInt(1, 1)),  # inv cf
    'CC': ExprOp("CC_U<", cf ^ ExprInt(1, 1)),  # inv cf
    'MI': ExprOp("CC_NEG", nf),
    'PL': ExprOp("CC_POS", nf),
    'VS': ExprOp("CC_sOVR", of),
    'VC': ExprOp("CC_sNOOVR", of),
    'HI': ExprOp("CC_U>", cf ^ ExprInt(1, 1), zf),  # inv cf
    'LS': ExprOp("CC_U<=", cf ^ ExprInt(1, 1), zf),  # inv cf
    'GE': ExprOp("CC_S>=", nf, of),
    'LT': ExprOp("CC_S<", nf, of),
    'GT': ExprOp("CC_S>", nf, of, zf),
    'LE': ExprOp("CC_S<=", nf, of, zf),
    # Constant conditions
    'AL': ExprInt(1, 1),
    'NV': ExprInt(0, 1)
}
jitter.run = False jitter.pc = 0 return True myjit.push_uint32_t(0x1337beef) myjit.add_breakpoint(0x1337beef, code_sentinelle) # Run myjit.init_run(run_addr) myjit.continue_run() # Check end assert myjit.run is False # Check resulting state / accessors assert myjit.cpu.EAX == 0 assert myjit.cpu.ECX == 4 # Check eval_expr eax = ExprId("RAX", 64)[:32] imm0, imm4, imm4_64 = ExprInt(0, 32), ExprInt(4, 32), ExprInt(4, 64) memdata = ExprMem(ExprInt(run_addr, 32), len(data) * 8) assert myjit.eval_expr(eax) == imm0 ## Due to ExprAff construction, imm4 is "promoted" to imm4_64 assert myjit.eval_expr(ExprAff(eax, imm4)) == imm4_64 assert myjit.eval_expr(eax) == imm4 ## Changes must be passed on myjit.cpu instance assert myjit.cpu.EAX == 4 ## Memory assert myjit.eval_expr(memdata).arg.arg == int(data[::-1].encode("hex"), 16)
def ast_int2expr(a):
    """Wrap the integer @a into a 32 bit ExprInt"""
    size = 32
    return ExprInt(a, size)
def simp_cc_conds(_, expr):
    """
    High level simplifications: a condition code applied to flag
    operations becomes a direct comparison of the flag operands.
    Example:
    CC_U<(FLAG_SUB_CF(A, B)) => A <u B

    @expr: expression to simplify; returned as is if no pattern matches
    """

    def operands():
        # Operands of the first flag operation of the condition
        return expr.args[0].args

    def invert(cond):
        # ExprCond built as (@cond, 0, 1), sized like the condition code
        return ExprCond(
            cond,
            ExprInt(0, expr.size),
            ExprInt(1, expr.size)
        )

    def sign_eq_with_zero_overflow(cc_name):
        # Three operands form: FLAG_SIGN_SUB, the integer 0, FLAG_EQ_CMP,
        # both flags sharing the same operands
        return (expr.is_op(cc_name) and
                len(expr.args) == 3 and
                expr.args[0].is_op("FLAG_SIGN_SUB") and
                expr.args[2].is_op("FLAG_EQ_CMP") and
                expr.args[0].args == expr.args[2].args and
                expr.args[1].is_int(0))

    # Unsigned comparisons on the carry of a subtraction
    if expr.is_op("CC_U>=") and test_cc_eq_args(expr, "FLAG_SUB_CF"):
        return invert(ExprOp(TOK_INF_UNSIGNED, *operands()))
    if expr.is_op("CC_U<") and test_cc_eq_args(expr, "FLAG_SUB_CF"):
        return ExprOp(TOK_INF_UNSIGNED, *operands())

    # Sign of a subtraction
    if expr.is_op("CC_NEG") and test_cc_eq_args(expr, "FLAG_SIGN_SUB"):
        return ExprOp(TOK_INF_SIGNED, *operands())
    if expr.is_op("CC_POS") and test_cc_eq_args(expr, "FLAG_SIGN_SUB"):
        return invert(ExprOp(TOK_INF_SIGNED, *operands()))

    # Comparison of a single value with zero
    if expr.is_op("CC_EQ") and test_cc_eq_args(expr, "FLAG_EQ"):
        arg = operands()[0]
        return ExprOp(TOK_EQUAL, arg, ExprInt(0, arg.size))
    if expr.is_op("CC_NE") and test_cc_eq_args(expr, "FLAG_EQ"):
        arg = operands()[0]
        return invert(ExprOp(TOK_EQUAL, arg, ExprInt(0, arg.size)))

    # Comparison of two values
    if expr.is_op("CC_NE") and test_cc_eq_args(expr, "FLAG_EQ_CMP"):
        return invert(ExprOp(TOK_EQUAL, *operands()))
    if expr.is_op("CC_EQ") and test_cc_eq_args(expr, "FLAG_EQ_CMP"):
        return ExprOp(TOK_EQUAL, *operands())

    # Test of the '&' of two values
    if expr.is_op("CC_NE") and test_cc_eq_args(expr, "FLAG_EQ_AND"):
        # NOTE(review): the result is the '&' of the operands, not a
        # condition sized like the other rules -- confirm sizes agree
        return ExprOp("&", *operands())
    if expr.is_op("CC_EQ") and test_cc_eq_args(expr, "FLAG_EQ_AND"):
        return invert(ExprOp("&", *operands()))

    # Signed comparisons
    if (expr.is_op("CC_S>") and
            test_cc_eq_args(
                expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF", "FLAG_EQ_CMP")):
        return invert(ExprOp(TOK_INF_EQUAL_SIGNED, *operands()))
    if sign_eq_with_zero_overflow("CC_S>"):
        return invert(ExprOp(TOK_INF_EQUAL_SIGNED, *operands()))
    if (expr.is_op("CC_S>=") and
            test_cc_eq_args(expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF")):
        return invert(ExprOp(TOK_INF_SIGNED, *operands()))
    if (expr.is_op("CC_S<") and
            test_cc_eq_args(expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF")):
        return ExprOp(TOK_INF_SIGNED, *operands())
    if (expr.is_op("CC_S<=") and
            test_cc_eq_args(
                expr, "FLAG_SIGN_SUB", "FLAG_SUB_OF", "FLAG_EQ_CMP")):
        return ExprOp(TOK_INF_EQUAL_SIGNED, *operands())
    if sign_eq_with_zero_overflow("CC_S<="):
        return ExprOp(TOK_INF_EQUAL_SIGNED, *operands())

    # Unsigned comparisons including equality
    if (expr.is_op("CC_U<=") and
            test_cc_eq_args(expr, "FLAG_SUB_CF", "FLAG_EQ_CMP")):
        return ExprOp(TOK_INF_EQUAL_UNSIGNED, *operands())
    if (expr.is_op("CC_U>") and
            test_cc_eq_args(expr, "FLAG_SUB_CF", "FLAG_EQ_CMP")):
        return invert(ExprOp(TOK_INF_EQUAL_UNSIGNED, *operands()))

    # Signed comparison on the flags of an addition: compare the first
    # operand with the opposite of the second one
    if (expr.is_op("CC_S<") and
            test_cc_eq_args(expr, "FLAG_SIGN_ADD", "FLAG_ADD_OF")):
        arg0, arg1 = operands()
        return ExprOp(TOK_INF_SIGNED, arg0, -arg1)

    return expr
def lsl(arg1, arg2, arg3):
    # Left shift: arg1 receives arg2 shifted left by arg3, the shift amount
    # being first and-ed with (arg3.size - 1).
    # NOTE(review): a plain assignment to a parameter would be a no-op in
    # regular Python; this looks like a body rewritten by a semantic
    # builder decorator -- confirm, and keep comments as '#' (no docstring)
    arg1 = arg2 << (arg3 & ExprInt(arg3.size - 1, arg3.size))
def lsr(arg1, arg2, arg3):
    # Right shift ('>>'): arg1 receives arg2 shifted right by arg3, the
    # shift amount being first and-ed with (arg3.size - 1).
    # NOTE(review): a plain assignment to a parameter would be a no-op in
    # regular Python; this looks like a body rewritten by a semantic
    # builder decorator -- confirm, and keep comments as '#' (no docstring)
    arg1 = arg2 >> (arg3 & ExprInt(arg3.size - 1, arg3.size))
def test_ClassDef(self):
    """
    Exercise SymbolicExecutionEngine on an x86_32 IR: overlapping memory
    reads, partial writes, reverting memory to its original value, symbol
    deletion, the func_read hook and a tiny (5 bit) address space.
    The checks are stateful: their order matters.
    """
    from miasm2.expression.expression import ExprInt, ExprId, ExprMem, \
        ExprCompose, ExprAssign
    from miasm2.arch.x86.sem import ir_x86_32
    from miasm2.core.locationdb import LocationDB
    from miasm2.ir.symbexec import SymbolicExecutionEngine
    from miasm2.ir.ir import AssignBlock

    loc_db = LocationDB()
    ira = ir_x86_32(loc_db)
    ircfg = ira.new_ircfg()

    id_x = ExprId('x', 32)
    id_a = ExprId('a', 32)
    id_b = ExprId('b', 32)
    id_c = ExprId('c', 32)
    id_d = ExprId('d', 32)
    id_e = ExprId('e', 64)

    # Initial symbolic state
    sb = SymbolicExecutionEngine(
        ira,
        {
            # Four single bytes at 0x4..0x7
            ExprMem(ExprInt(0x4, 32), 8): ExprInt(0x44, 8),
            ExprMem(ExprInt(0x5, 32), 8): ExprInt(0x33, 8),
            ExprMem(ExprInt(0x6, 32), 8): ExprInt(0x22, 8),
            ExprMem(ExprInt(0x7, 32), 8): ExprInt(0x11, 8),

            # Symbolic 32 bit values
            ExprMem(ExprInt(0x20, 32), 32): id_x,

            ExprMem(ExprInt(0x40, 32), 32): id_x,
            ExprMem(ExprInt(0x44, 32), 32): id_a,

            # One 32 bit integer
            ExprMem(ExprInt(0x54, 32), 32): ExprInt(0x11223344, 32),

            # Memory at a symbolic address, and registers
            ExprMem(id_a, 32): ExprInt(0x11223344, 32),
            id_a: ExprInt(0, 32),
            id_b: ExprInt(0, 32),

            # Memory whose value is another memory
            ExprMem(id_c, 32): ExprMem(id_d + ExprInt(0x4, 32), 32),
            ExprMem(id_c + ExprInt(0x4, 32), 32): ExprMem(id_d + ExprInt(0x8, 32), 32),
        })

    self.assertEqual(sb.eval_expr(ExprInt(1, 32) - ExprInt(1, 32)), ExprInt(0, 32))

    ## Test with unknown mem + integer (bytes stored at 0x4..0x7)
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0, 32), 32)), ExprMem(ExprInt(0, 32), 32))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(1, 32), 32)),
        ExprCompose(ExprMem(ExprInt(1, 32), 24), ExprInt(0x44, 8)))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(2, 32), 32)),
        ExprCompose(ExprMem(ExprInt(2, 32), 16), ExprInt(0x3344, 16)))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(3, 32), 32)),
        ExprCompose(ExprMem(ExprInt(3, 32), 8), ExprInt(0x223344, 24)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(4, 32), 32)), ExprInt(0x11223344, 32))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(5, 32), 32)),
        ExprCompose(ExprInt(0x112233, 24), ExprMem(ExprInt(8, 32), 8)))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(6, 32), 32)),
        ExprCompose(ExprInt(0x1122, 16), ExprMem(ExprInt(8, 32), 16)))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(7, 32), 32)),
        ExprCompose(ExprInt(0x11, 8), ExprMem(ExprInt(8, 32), 24)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(8, 32), 32)), ExprMem(ExprInt(8, 32), 32))

    ## Test with unknown mem + integer (32 bit integer stored at 0x54)
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x50, 32), 32)), ExprMem(ExprInt(0x50, 32), 32))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(0x51, 32), 32)),
        ExprCompose(ExprMem(ExprInt(0x51, 32), 24), ExprInt(0x44, 8)))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(0x52, 32), 32)),
        ExprCompose(ExprMem(ExprInt(0x52, 32), 16), ExprInt(0x3344, 16)))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(0x53, 32), 32)),
        ExprCompose(ExprMem(ExprInt(0x53, 32), 8), ExprInt(0x223344, 24)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x54, 32), 32)), ExprInt(0x11223344, 32))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(0x55, 32), 32)),
        ExprCompose(ExprInt(0x112233, 24), ExprMem(ExprInt(0x58, 32), 8)))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(0x56, 32), 32)),
        ExprCompose(ExprInt(0x1122, 16), ExprMem(ExprInt(0x58, 32), 16)))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(0x57, 32), 32)),
        ExprCompose(ExprInt(0x11, 8), ExprMem(ExprInt(0x58, 32), 24)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x58, 32), 32)), ExprMem(ExprInt(0x58, 32), 32))

    ## Test with unknown mem + id (id_x stored at 0x20)
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x1D, 32), 32)),
                     ExprCompose(ExprMem(ExprInt(0x1D, 32), 24), id_x[:8]))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(0x1E, 32), 32)),
        ExprCompose(ExprMem(ExprInt(0x1E, 32), 16), id_x[:16]))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x1F, 32), 32)),
                     ExprCompose(ExprMem(ExprInt(0x1F, 32), 8), id_x[:24]))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x20, 32), 32)), id_x)
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x21, 32), 32)),
                     ExprCompose(id_x[8:], ExprMem(ExprInt(0x24, 32), 8)))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(0x22, 32), 32)),
        ExprCompose(id_x[16:], ExprMem(ExprInt(0x24, 32), 16)))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(0x23, 32), 32)),
        ExprCompose(id_x[24:], ExprMem(ExprInt(0x24, 32), 24)))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x24, 32), 32)), ExprMem(ExprInt(0x24, 32), 32))

    ## Partial read
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(4, 32), 8)), ExprInt(0x44, 8))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x20, 32), 8)), id_x[:8])
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x23, 32), 8)), id_x[24:])

    ## Merge
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x40, 32), 64)), ExprCompose(id_x, id_a))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x42, 32), 32)), ExprCompose(id_x[16:], id_a[:16]))

    # Merge memory
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x100, 32), 32)), ExprMem(ExprInt(0x100, 32), 32))
    self.assertEqual(sb.eval_expr(ExprMem(id_c + ExprInt(0x2, 32), 32)), ExprMem(id_d + ExprInt(0x6, 32), 32))

    ## Func read
    def custom_func_read(mem):
        # The 32 bit read at 0x1000 yields id_x, others are left unchanged
        if mem == ExprMem(ExprInt(0x1000, 32), 32):
            return id_x
        return mem

    sb.func_read = custom_func_read

    ## Unmodified read
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(4, 32), 8)), ExprInt(0x44, 8))

    ## Modified read
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x1000, 32), 32)), id_x)

    ## Apply_change / eval_ir / apply_expr

    ## x = a (with a = 0x0)
    assignblk = AssignBlock({id_x: id_a})
    sb.eval_updt_assignblk(assignblk)
    self.assertEqual(sb.eval_expr(id_x), ExprInt(0, 32))

    ## x = a (without replacing 'a' with 0x0)
    sb.apply_change(id_x, id_a)
    self.assertEqual(sb.eval_expr(id_x), id_a)

    ## x = a (with a = 0x0)
    self.assertEqual(sb.eval_updt_expr(assignblk.dst2ExprAssign(id_x)), ExprInt(0, 32))
    self.assertEqual(sb.eval_expr(id_x), ExprInt(0, 32))
    self.assertEqual(sb.eval_updt_expr(id_x), ExprInt(0, 32))

    sb.dump()

    ## state: collect every expression read by the modified symbols
    reads = set()
    for dst, src in sb.modified():
        reads.update(ExprAssign(dst, src).get_r())

    self.assertEqual(
        reads,
        set([
            id_x, id_a,
            ExprMem(id_d + ExprInt(0x4, 32), 32),
            ExprMem(id_d + ExprInt(0x8, 32), 32),
        ]))

    # Erase low id_x byte with 0xFF
    sb.apply_change(ExprMem(ExprInt(0x20, 32), 8), ExprInt(0xFF, 8))
    state = dict(sb.modified(ids=False))
    self.assertEqual(state[ExprMem(ExprInt(0x20, 32), 8)], ExprInt(0xFF, 8))
    self.assertEqual(state[ExprMem(ExprInt(0x21, 32), 24)], id_x[8:32])

    # Erase high id_x byte with 0xEE
    sb.apply_change(ExprMem(ExprInt(0x23, 32), 8), ExprInt(0xEE, 8))

    state = dict(sb.modified(ids=False))
    self.assertEqual(state[ExprMem(ExprInt(0x20, 32), 8)], ExprInt(0xFF, 8))
    self.assertEqual(state[ExprMem(ExprInt(0x21, 32), 16)], id_x[8:24])
    self.assertEqual(state[ExprMem(ExprInt(0x23, 32), 8)], ExprInt(0xEE, 8))

    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(0x22, 32), 32)),
        ExprCompose(id_x[16:24], ExprInt(0xEE, 8), ExprMem(ExprInt(0x24, 32), 16)))

    # Erase low byte of 0x11223344 with 0xFF at 0x54
    sb.apply_change(ExprMem(ExprInt(0x54, 32), 8), ExprInt(0xFF, 8))

    # Erase one byte of the 0x11223344 stored at id_a with 0xFF
    # (the write targets id_a + 0x1)
    sb.apply_change(ExprMem(id_a + ExprInt(0x1, 32), 8), ExprInt(0xFF, 8))
    state = dict(sb.modified(ids=False))
    self.assertEqual(state[ExprMem(id_a + ExprInt(0x1, 32), 8)], ExprInt(0xFF, 8))
    self.assertEqual(state[ExprMem(id_a + ExprInt(0x2, 32), 16)], ExprInt(0x1122, 16))

    # Write uint32_t at 0xFFFFFFFE: its upper half is then read back at 0
    sb.apply_change(ExprMem(ExprInt(0xFFFFFFFE, 32), 32), ExprInt(0x11223344, 32))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0, 32), 16)), ExprInt(0x1122, 16))

    # Revert memory to original value at 0x42
    sb.apply_change(ExprMem(ExprInt(0x42, 32), 32), ExprMem(ExprInt(0x42, 32), 32))
    self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x42, 32), 32)), ExprMem(ExprInt(0x42, 32), 32))

    # Revert memory to original value at c + 0x2
    sb.apply_change(ExprMem(id_c + ExprInt(0x2, 32), 32), ExprMem(id_c + ExprInt(0x2, 32), 32))
    self.assertEqual(sb.eval_expr(ExprMem(id_c + ExprInt(0x2, 32), 32)), ExprMem(id_c + ExprInt(0x2, 32), 32))

    # Test del symbol
    del sb.symbols[id_a]
    sb.dump()
    del sb.symbols[ExprMem(id_a, 8)]
    print "*" * 40, 'Orig:'
    sb.dump()

    sb_cp = sb.symbols.copy()
    print "*" * 40, 'Copy:'
    sb_cp.dump()

    # Add symbol at address limit
    sb.apply_change(ExprMem(ExprInt(0xFFFFFFFE, 32), 32), id_c)
    sb.dump()
    found = False
    for dst, src in sb.symbols.iteritems():
        if dst == ExprMem(ExprInt(0xFFFFFFFE, 32), 32) and src == id_c:
            found = True
    assert found

    # Add symbol at address limit
    sb.apply_change(ExprMem(ExprInt(0x7FFFFFFE, 32), 32), id_c)
    sb.dump()
    found = False
    for dst, src in sb.symbols.iteritems():
        if dst == ExprMem(ExprInt(0x7FFFFFFE, 32), 32) and src == id_c:
            found = True
    assert found

    # Add truncated symbol at address limit
    sb.apply_change(ExprMem(ExprInt(0xFFFFFFFC, 32), 64), id_e)
    # Revert parts of memory
    sb.apply_change(ExprMem(ExprInt(0xFFFFFFFC, 32), 16), ExprMem(ExprInt(0xFFFFFFFC, 32), 16))
    sb.apply_change(ExprMem(ExprInt(0x2, 32), 16), ExprMem(ExprInt(0x2, 32), 16))
    sb.dump()
    found = False
    for dst, src in sb.symbols.iteritems():
        if dst == ExprMem(ExprInt(0xFFFFFFFE, 32), 32) and src == id_e[16:48]:
            found = True
    assert found

    # An engine created without initial state can be dumped
    sb_empty = SymbolicExecutionEngine(ira)
    sb_empty.dump()

    # Test memory full
    print 'full'
    arch_addr8 = ir_x86_32(loc_db)
    ircfg = arch_addr8.new_ircfg()
    # Hack to obtain tiny address space
    arch_addr8.addrsize = 5
    sb_addr8 = SymbolicExecutionEngine(arch_addr8)
    sb_addr8.dump()
    # Fulfill memory: 256 bits written with 5 bit addresses
    sb_addr8.apply_change(ExprMem(ExprInt(0, 5), 256), ExprInt(0, 256))
    sb_addr8.dump()
    variables = sb_addr8.symbols.items()
    assert variables == [(ExprMem(ExprInt(0, 5), 256), ExprInt(0, 256))]

    print sb_addr8.symbols.symbols_mem

    # Same size write at another address: it must replace the previous one
    sb_addr8.apply_change(ExprMem(ExprInt(0x5, 5), 256), ExprInt(0x123, 256))
    sb_addr8.dump()
    variables = sb_addr8.symbols.items()
    assert variables == [(ExprMem(ExprInt(0x5, 5), 256), ExprInt(0x123, 256))]
    print sb_addr8.symbols.symbols_mem

    print 'dump'
    sb_addr8.symbols.symbols_mem.dump()

    sb.dump()
    # Deleting this 32 bit memory is expected to raise KeyError
    try:
        del sb.symbols.symbols_mem[ExprMem(ExprInt(0xFFFFFFFF, 32), 32)]
    except KeyError:
        # ok
        pass
    else:
        raise RuntimeError("Should raise error!")

    # The 16 bit deletion at the same address is expected to succeed
    del sb.symbols.symbols_mem[ExprMem(ExprInt(0xFFFFFFFF, 32), 16)]
    sb.dump()
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(0xFFFFFFFE, 32), 32)),
        ExprCompose(id_e[16:24], ExprMem(ExprInt(0xFFFFFFFF, 32), 16), id_e[40:48]))
    sb.symbols.symbols_mem.delete_partial(
        ExprMem(ExprInt(0xFFFFFFFF, 32), 32))
    self.assertEqual(
        sb.eval_expr(ExprMem(ExprInt(0xFFFFFFFE, 32), 32)),
        ExprCompose(id_e[16:24], ExprMem(ExprInt(0xFFFFFFFF, 32), 24)))
    sb.dump()

    # Only the byte at 0xFFFFFFFE must remain known
    assert ExprMem(ExprInt(0xFFFFFFFE, 32), 8) in sb.symbols
    assert ExprMem(ExprInt(0xFFFFFFFE, 32), 32) not in sb.symbols
    assert sb.symbols.symbols_mem.contains_partial(
        ExprMem(ExprInt(0xFFFFFFFE, 32), 32))
    assert not sb.symbols.symbols_mem.contains_partial(
        ExprMem(ExprInt(0xFFFFFFFF, 32), 8))

    assert sb_addr8.symbols.keys() == [ExprMem(ExprInt(0x5, 5), 256)]
def asr(arg1, arg2, arg3):
    # 'a>>' operation: arg1 receives ExprOp('a>>') of arg2 and of arg3
    # and-ed with (arg3.size - 1).
    # NOTE(review): a plain assignment to a parameter would be a no-op in
    # regular Python; this looks like a body rewritten by a semantic
    # builder decorator -- confirm, and keep comments as '#' (no docstring)
    arg1 = ExprOp('a>>', arg2, (arg3 & ExprInt(arg3.size - 1, arg3.size)))
def dst_to_c(self, src):
    """
    Translate the destination @src into C code
    @src: Expr, or a value wrapped into an ExprInt of the size of self.PC
    """
    if isinstance(src, Expr):
        dst_expr = src
    else:
        dst_expr = ExprInt(src, self.PC.size)
    return self.id_to_c(dst_expr)
def bl(arg1):
    # Branch to arg1: both PC and the IR destination receive the target.
    # NOTE(review): `ir` and `instr` are free names and nothing is returned
    # -- presumably this body is interpreted by a semantic builder/decorator
    # that is not visible here; confirm.
    PC = arg1
    ir.IRDst = arg1
    # LR receives the 64-bit value instr.offset + instr.l
    LR = ExprInt(instr.offset + instr.l, 64)
import pickle from miasm2.expression.expression import ExprInt, ExprAff, ExprId, \ Expr, ExprCompose, ExprMem a = ExprId("test", 8) b = ExprInt(1338, 8) c = a + b d = ExprCompose(a, b) e = ExprMem(a, 32) f = a[:8] aff = ExprAff(a, b) print 'Pickling' out = pickle.dumps((a, b, c, d, e, f, aff)) print 'Unpickling' new_a, new_b, new_c, new_d, new_e, new_f, new_aff = pickle.loads(out) print 'Result' print a, b, c, aff print id(a), id(b), id(c), id(d), id(e), id(f), id(aff) print new_a, new_b, new_c, new_d, new_e, new_f, new_aff print id(new_a), id(new_b), id(new_c), id(new_d), id(new_e), id(new_f), id(new_aff) assert a == new_a assert b == new_b assert c == new_c assert d == new_d assert e == new_e assert f == new_f assert aff == new_aff
def update_flag_sub_cf(op1, op2):
    "Compute CF in @op1 - @op2"
    # CF is assigned the "FLAG_SUB_CF" operation result XORed with a 1-bit 1
    sub_cf = ExprOp("FLAG_SUB_CF", op1, op2)
    flipped = sub_cf ^ ExprInt(1, 1)
    return [ExprAssign(cf, flipped)]
# 32-bit symbolic identifiers
r, x, y = (ExprId(name, 32) for name in ("r", "x", "y"))

# Narrower identifiers: one byte and one flag bit
u8 = ExprId("u8", 8)
zf = ExprId('zf', 1)

# Initial-value identifiers
a_init, b_init, c_init, d_init = (
    ExprId(name, 32) for name in ("a_init", "b_init", "c_init", "d_init")
)
r_init = ExprId("r_init", 32)  # Return register

pc, sp = (ExprId(name, 32) for name in ("pc", "sp"))

# Integer constants: CST0..CST3 are 32-bit, CSTX_8 is 8-bit
CST0, CST1, CST2, CST3 = (ExprInt(value, 32) for value in range(4))
CSTX_8 = ExprInt(12, 8)

# Locations "lbl0".."lbl7" registered at offsets 0..7, in increasing order
LBL0, LBL1, LBL2, LBL3, LBL4, LBL5, LBL6, LBL7 = (
    loc_db.add_location("lbl%d" % num, num) for num in range(8)
)

IRDst = ExprId('IRDst', 32)
def sdiv(arg1, arg2, arg3):
    # Signed division of arg2 by arg3, bound to arg1 when arg3 is non-zero;
    # otherwise exception_flags is set to EXCEPT_DIV_BY_ZERO.
    # NOTE(review): as plain Python the else branch would read
    # exception_flags before assigning it, and nothing is returned --
    # presumably this body is interpreted by a semantic builder/decorator
    # that is not visible here; confirm.
    if arg3:
        arg1 = ExprOp('sdiv', arg2, arg3)
    else:
        exception_flags = ExprInt(EXCEPT_DIV_BY_ZERO, exception_flags.size)
def test_madd(self):
    """Test MADD execution"""

    def reg(name):
        # 32-bit register identifier
        return ExprId(name, 32)

    def imm(value):
        # 32-bit integer constant
        return ExprInt(value, 32)

    # MADD Rn,Rm
    # Each case is (initial HI and LO value, expected LO); the expected HI
    # is 0xFFFFFFFF in both cases.
    for initial, expected_lo in ((0, 0xFFFFFF80), (1, 0xFFFFFF81)):
        exec_instruction(
            "MADD R0, R1",
            [(reg("R0"), imm(0x80)),
             (reg("R1"), imm(0xFFFFFFFF)),
             (reg("HI"), imm(initial)),
             (reg("LO"), imm(initial))],
            [(reg("HI"), imm(0xFFFFFFFF)),
             (reg("LO"), imm(expected_lo))])
def tbnz(arg1, arg2, arg3):
    # Test bit arg2 of arg1 and branch to arg3 when it is set.
    # NOTE(review): `ir` and `instr` are free names and nothing is returned
    # -- presumably this body is interpreted by a semantic builder/decorator
    # that is not visible here; confirm.

    # Mask with a single 1 shifted left by arg2, sized like arg1
    bitmask = ExprInt(1, arg1.size) << arg2
    # Taken: arg3. Not taken: the 64-bit location given by
    # ir.get_next_loc_key(instr).
    dst = arg3 if arg1 & bitmask else ExprLoc(ir.get_next_loc_key(instr), 64)
    PC = dst
    ir.IRDst = dst
def simp_cst_propagation(e_s, expr):
    """This pass includes:
     - Constant folding
     - Common logical identities
     - Common binary identities

    @e_s: expression simplifier; it is called on a few sub-expressions
          (for instance e_s(X - Y) or e_s(X.msb()))
    @expr: operation expression to simplify (its .op, .args, .size and
           .mask are read)

    Return either one of the (possibly rewritten) operands, a new
    expression implementing the matched identity, or an ExprOp rebuilt from
    the remaining arguments. The rules are order dependent: earlier rules
    may modify @args / @op_name for the later ones.
    """

    # Work on a mutable copy of the arguments
    args = list(expr.args)
    op_name = expr.op
    # Integer folding
    # int OP int => int
    # TODO: <<< >>> << >> are architecture dependent
    if op_name in op_propag_cst:
        # Fold the two trailing integers until at most one is left
        while (len(args) >= 2 and
               args[-1].is_int() and
               args[-2].is_int()):
            int2 = args.pop()
            int1 = args.pop()
            if op_name == '+':
                out = int1.arg + int2.arg
            elif op_name == '*':
                out = int1.arg * int2.arg
            elif op_name == '**':
                out = int1.arg ** int2.arg
            elif op_name == '^':
                out = int1.arg ^ int2.arg
            elif op_name == '&':
                out = int1.arg & int2.arg
            elif op_name == '|':
                out = int1.arg | int2.arg
            elif op_name == '>>':
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg >> int2.arg
            elif op_name == '<<':
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg << int2.arg
            elif op_name == 'a>>':
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                if tmp2 > int1.size:
                    # Shift larger than the size: only the sign remains
                    is_signed = int(int1) & (1 << (int1.size - 1))
                    if is_signed:
                        out = -1
                    else:
                        out = 0
                else:
                    out = mod_size2uint[int1.arg.size](tmp1 >> tmp2)
            elif op_name == '>>>':
                shifter = int2.arg % int2.size
                out = (int1.arg >> shifter) | (int1.arg << (int2.size - shifter))
            elif op_name == '<<<':
                shifter = int2.arg % int2.size
                out = (int1.arg << shifter) | (int1.arg >> (int2.size - shifter))
            elif op_name == '/':
                out = int1.arg / int2.arg
            elif op_name == '%':
                out = int1.arg % int2.arg
            elif op_name == 'sdiv':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)
            elif op_name == 'smod':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'umod':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'udiv':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)

            args.append(ExprInt(out, int1.size))

    # cnttrailzeros(int) => int
    if op_name == "cnttrailzeros" and args[0].is_int():
        i = 0
        while args[0].arg & (1 << i) == 0 and i < args[0].size:
            i += 1
        return ExprInt(i, args[0].size)

    # cntleadzeros(int) => int
    if op_name == "cntleadzeros" and args[0].is_int():
        if args[0].arg == 0:
            return ExprInt(args[0].size, args[0].size)
        i = args[0].size - 1
        while args[0].arg & (1 << i) == 0:
            i -= 1
        return ExprInt(expr.size - (i + 1), args[0].size)

    # -(-(A)) => A
    if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and
            len(args[0].args) == 1):
        return args[0].args[0]

    # -(int) => -int
    if op_name == '-' and len(args) == 1 and args[0].is_int():
        return ExprInt(-int(args[0]), expr.size)
    # A op 0 => A
    if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
        if args[-1].is_int(0):
            args.pop()
    # A - 0 => A
    if op_name == '-' and len(args) > 1 and args[-1].is_int(0):
        assert len(args) == 2  # Op '-' with more than 2 args: SanityCheckError
        return args[0]

    # A * 1 => A
    if op_name == "*" and len(args) > 1 and args[-1].is_int(1):
        args.pop()

    # For canonical form
    # A * -1 => - A
    if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask:
        args.pop()
        args[-1] = - args[-1]

    # op A => A
    if op_name in ['+', '*', '^', '&', '|', '>>', '<<',
                   'a>>', '<<<', '>>>', 'sdiv', 'smod', 'umod', 'udiv'] and len(args) == 1:
        return args[0]

    # A-B => A + (-B)
    if op_name == '-' and len(args) > 1:
        if len(args) > 2:
            raise ValueError(
                'sanity check fail on expr -: should have one or 2 args ' +
                '%r %s' % (expr, expr))
        return ExprOp('+', args[0], -args[1])

    # A op 0 => 0
    if op_name in ['&', "*"] and args[-1].is_int(0):
        return ExprInt(0, expr.size)

    # - (A + B +...) => -A + -B + -C
    if op_name == '-' and len(args) == 1 and args[0].is_op('+'):
        args = [-a for a in args[0].args]
        return ExprOp('+', *args)

    # -(a?int1:int2) => (a?-int1:-int2)
    if (op_name == '-' and len(args) == 1 and
            args[0].is_cond() and
            args[0].src1.is_int() and args[0].src2.is_int()):
        int1 = args[0].src1
        int2 = args[0].src2
        int1 = ExprInt(-int1.arg, int1.size)
        int2 = ExprInt(-int2.arg, int2.size)
        return ExprCond(args[0].cond, int1, int2)

    # Pairwise cancellation / deduplication of arguments. @j is not
    # advanced after a deletion, as the next candidate took its place.
    i = 0
    while i < len(args) - 1:
        j = i + 1
        while j < len(args):
            # A ^ A => 0
            if op_name == '^' and args[i] == args[j]:
                args[i] = ExprInt(0, args[i].size)
                del args[j]
                continue
            # A + (- A) => 0
            if op_name == '+' and args[j].is_op("-"):
                if len(args[j].args) == 1 and args[i] == args[j].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # (- A) + A => 0
            if op_name == '+' and args[i].is_op("-"):
                if len(args[i].args) == 1 and args[j] == args[i].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # A | A => A
            if op_name == '|' and args[i] == args[j]:
                del args[j]
                continue
            # A & A => A
            if op_name == '&' and args[i] == args[j]:
                del args[j]
                continue
            j += 1
        i += 1

    if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1:
        return args[0]

    # A <<< A.size => A
    if (op_name in ['<<<', '>>>'] and
            args[1].is_int() and
            args[1].arg == args[0].size):
        return args[0]

    # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
    if (op_name in ['<<<', '>>>'] and
            args[0].is_op() and
            args[0].op in ['<<<', '>>>']):
        A = args[0].args[0]
        X = args[0].args[1]
        Y = args[1]
        if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size):
            # Opposite rotations by the same amount cancel each other
            return args[0].args[0]
        elif X.is_int() and Y.is_int():
            new_X = int(X) % expr.size
            new_Y = int(Y) % expr.size
            if op_name == args[0].op:
                rot = (new_X + new_Y) % expr.size
                op = op_name
            else:
                rot = new_Y - new_X
                op = op_name
                if rot < 0:
                    # Negative rotation: rotate the other way instead
                    rot = - rot
                    op = {">>>": "<<<", "<<<": ">>>"}[op_name]
            args = [A, ExprInt(rot, expr.size)]
            op_name = op

        else:
            # Do not consider this case, too tricky (overflow on addition /
            # subtraction)
            pass

    # A >> X >> Y  =>  A >> (X+Y) if X + Y does not overflow
    # To be sure, only consider the simplification when X.msb and Y.msb are 0
    if (op_name in ['<<', '>>'] and
            args[0].is_op(op_name)):
        X = args[0].args[1]
        Y = args[1]
        if (e_s(X.msb()) == ExprInt(0, 1) and
                e_s(Y.msb()) == ExprInt(0, 1)):
            args = [args[0].args[0], X + Y]

    # ((var >> int1) << int1) => var & mask
    # ((var << int1) >> int1) => var & mask
    # Only meaningful when the inner and the outer shifts go in opposite
    # directions, hence the comparison of the two operator names.
    if (op_name in ['<<', '>>'] and
            args[0].is_op() and
            args[0].op in ['<<', '>>'] and
            op_name != args[0].op):
        var = args[0].args[0]
        int1 = args[0].args[1]
        int2 = args[1]
        if int1 == int2 and int1.is_int() and int(int1) < expr.size:
            if op_name == '>>':
                # Upper int1 bits are cleared
                mask = ExprInt((1 << (expr.size - int(int1))) - 1, expr.size)
            else:
                # Lower int1 bits are cleared
                mask = ExprInt(
                    ((1 << int(int1)) - 1) ^ ((1 << expr.size) - 1),
                    expr.size
                )
            ret = var & mask
            return ret

    # A & A.mask => A
    if op_name == "&" and args[-1] == expr.mask:
        return ExprOp('&', *args[:-1])

    # A | A.mask => A.mask
    if op_name == "|" and args[-1] == expr.mask:
        return args[-1]

    # ! (!X + int) => X - int
    # TODO

    # ((A & mask) >> shift) with mask < 2**shift => 0
    if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
        if (args[0].args[1].is_int() and
                2 ** args[1].arg > args[0].args[1].arg):
            return ExprInt(0, args[0].size)

    # parity(int) => int
    if op_name == 'parity' and args[0].is_int():
        return ExprInt(parity(int(args[0])), 1)

    # (-a) * b * (-c) * (-d) => -(a * b * c * d)
    # The negations are stripped from the factors; an odd count leaves a
    # single negation in front of the product.
    if op_name == "*" and len(args) > 1:
        new_args = []
        counter = 0
        for arg in args:
            if arg.is_op('-') and len(arg.args) == 1:
                new_args.append(arg.args[0])
                counter += 1
            else:
                new_args.append(arg)
        if counter % 2:
            return -ExprOp(op_name, *new_args)
        args = new_args

    # -(a * b * int) => a * b * (-int)
    if op_name == "-" and args[0].is_op('*') and args[0].args[-1].is_int():
        args = args[0].args
        return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)]))

    # A << int with A ExprCompose => move index
    if (op_name == "<<" and args[0].is_compose() and
            args[1].is_int() and int(args[1]) != 0):
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # Shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index+shift, index+shift+arg.size))
        # Filter out expressions shifted beyond the final size
        filter_args = []
        min_index = final_size
        for tmp, start, stop in new_args:
            if start >= final_size:
                continue
            if stop > final_size:
                tmp = tmp[:tmp.size - (stop - final_size)]
            filter_args.append(tmp)
            min_index = min(start, min_index)
        # Create entry 0
        assert min_index != 0
        tmp = ExprInt(0, min_index)
        args = [tmp] + filter_args
        return ExprCompose(*args)

    # A >> int with A ExprCompose => move index
    if op_name == ">>" and args[0].is_compose() and args[1].is_int():
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # Shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index-shift, index+arg.size-shift))
        # Filter out expressions shifted below bit 0
        filter_args = []
        max_index = 0
        for tmp, start, stop in new_args:
            if stop <= 0:
                continue
            if start < 0:
                tmp = tmp[-start:]
            filter_args.append(tmp)
            max_index = max(stop, max_index)
        # Pad the upper part with zeros
        tmp = ExprInt(0, final_size - max_index)
        args = filter_args + [tmp]
        return ExprCompose(*args)

    # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b)
    if op_name in ['|', '&', '^'] and all([arg.is_compose() for arg in args]):
        bounds = set()
        for arg in args:
            bound = tuple([tmp.size for tmp in arg.args])
            bounds.add(bound)
        if len(bounds) == 1:
            new_args = [[tmp] for tmp in args[0].args]
            for sub_arg in args[1:]:
                for i, tmp in enumerate(sub_arg.args):
                    new_args[i].append(tmp)
            args = []
            for i, arg in enumerate(new_args):
                args.append(ExprOp(op_name, *arg))
            return ExprCompose(*args)

    return ExprOp(op_name, *args)
def adrp(arg1, arg2):
    # arg1 receives PC with its low 12 bits cleared (mask
    # 0xfffffffffffff000), plus arg2.
    # NOTE(review): the value is only assigned to the name arg1 and never
    # returned -- presumably this body is interpreted by a semantic
    # builder/decorator that is not visible here; confirm.
    arg1 = (PC & ExprInt(0xfffffffffffff000, 64)) + arg2
def simp_slice(e_s, expr):
    """Slice optimization

    @e_s: expression simplifier (its expr_simp_wrapper is used on
          sub-slices)
    @expr: slice expression; its .arg, .start and .stop are read

    Return the simplified expression, or @expr itself when no rule applies.
    """
    inner = expr.arg
    lo = expr.start
    hi = expr.stop

    # slice(A, 0, a.size) => A
    if lo == 0 and hi == inner.size:
        return inner

    # Slice(int) => int
    if inner.is_int():
        width = hi - lo
        keep = (1 << width) - 1
        return ExprInt(int((inner.arg >> lo) & keep), width)

    # Slice(Slice(A, x), y) => Slice(A, z)
    if inner.is_slice():
        width = hi - lo
        if width > inner.stop - inner.start:
            raise ValueError('slice in slice: getting more val', str(expr))
        base = lo + inner.start
        return ExprSlice(inner.arg, base, base + width)

    if inner.is_compose():
        # Slice(Compose(A), x) => Slice(A, y)
        # A single component holding the whole slice is enough
        for offset, part in inner.iter_args():
            if offset <= lo and offset + part.size >= hi:
                return part[lo - offset:hi - offset]

        # Slice(Compose(A, B, C), x) => Compose(A, B, C) with truncated A/B/C
        pieces = []
        for offset, part in inner.iter_args():
            part_end = offset + part.size
            # Component entirely before the slice start or after its stop
            if lo >= part_end or hi <= offset:
                continue
            # Component fully included in the slice
            if lo <= offset and part_end <= hi:
                pieces.append(part)
                continue
            # Component truncated at start (0 when it is not)
            first = lo - offset if lo > offset else 0
            # Component truncated at stop (whole size when it is not)
            last = hi - offset - first if hi < part_end else part.size
            pieces.append(part[first:last])
        return ExprCompose(*pieces)

    # ExprMem(x, size)[:A] => ExprMem(x, a)
    # XXXX todo hum, is it safe?
    if (inner.is_mem() and lo == 0 and
            inner.size > hi and hi % 8 == 0):
        return ExprMem(inner.ptr, size=hi)

    # distributivity of slice and &
    # (a & int)[x:y] => 0 if int[x:y] == 0
    if inner.is_op("&") and inner.args[-1].is_int():
        sliced_cst = e_s.expr_simp_wrapper(inner.args[-1][lo:hi])
        if sliced_cst.is_int(0):
            return sliced_cst

    # distributivity of slice and exprcond
    # (a?int1:int2)[x:y] => (a?int1[x:y]:int2[x:y])
    # (a?compose1:compose2)[x:y] => (a?compose1[x:y]:compose2[x:y])
    if (inner.is_cond() and
            (inner.src1.is_int() or inner.src1.is_compose()) and
            (inner.src2.is_int() or inner.src2.is_compose())):
        then_part = inner.src1[lo:hi]
        else_part = inner.src2[lo:hi]
        return ExprCond(inner.cond, then_part, else_part)

    # (a * int)[0:y] => (a[0:y] * int[0:y])
    if lo == 0 and inner.is_op("*") and inner.args[-1].is_int():
        factors = [e_s.expr_simp_wrapper(factor[lo:hi])
                   for factor in inner.args]
        return ExprOp(inner.op, *factors)

    # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size
    # (a << int)[x:y] => a[x-int:y-int] with x-int >= 0
    if (inner.is_op() and inner.op in [">>", "<<"] and
            inner.args[1].is_int()):
        shifted, amount = inner.args
        amount = int(amount)
        if inner.op == ">>":
            if amount + hi <= shifted.size:
                return shifted[lo + amount:hi + amount]
        elif inner.op == "<<":
            if lo - amount >= 0:
                return shifted[lo - amount:hi - amount]
        else:
            raise ValueError('Bad case')

    return expr
def svc(arg1):
    # Set exception_flags to EXCEPT_INT_XX and interrupt_num to the integer
    # value of arg1, each keeping its own size.
    # NOTE(review): as plain Python both names would be read before being
    # assigned, and nothing is returned -- presumably this body is
    # interpreted by a semantic builder/decorator that is not visible here;
    # confirm.
    exception_flags = ExprInt(EXCEPT_INT_XX, exception_flags.size)
    interrupt_num = ExprInt(int(arg1), interrupt_num.size)
def miasm_dis(r2_op, r2_address, r2_buffer, r2_length):
    """Disassemble an instruction using miasm.

    @r2_op: radare2 structure to fill (cast to "RAsmOp_r2m2*")
    @r2_address: offset assigned to the disassembled instruction
    @r2_buffer: buffer holding the opcode bytes (cast to "char*")
    @r2_length: number of bytes unpacked from @r2_buffer

    Nothing is returned: the size, the assembly string and the hexadecimal
    bytes are written into the RAsmOp structure. The structure is left
    untouched when no miasm machine is available. Any disassembly error is
    reported through an error string instead of being raised.
    """

    # Cast radare2 variables
    rasmop = ffi.cast("RAsmOp_r2m2*", r2_op)
    opcode = ffi.cast("char*", r2_buffer)

    # Prepare the opcode
    opcode = ffi.unpack(opcode, r2_length)

    # Get the miasm2 machine
    machine = miasm_machine()
    if machine is None:
        return

    # Disassemble the opcode
    try:
        mode = machine.dis_engine().attrib
        instr = machine.mn().dis(opcode, mode)
        instr.offset = r2_address
        if instr.dstflow():

            # Remember ExprInt arguments sizes (None for other arguments)
            args_size = [
                arg.size if isinstance(arg, ExprInt) else None
                for arg in instr.args
            ]

            # Adjust arguments values using the instruction offset
            instr.dstflow2label(AsmSymbolPool())

            # Convert label back to ExprInt
            for i, arg in enumerate(instr.args):
                if args_size[i] is None:
                    continue
                if isinstance(arg, ExprId) and \
                   isinstance(arg.name, AsmLabel):
                    # The address is the hexadecimal field following the
                    # first ":" of the label string
                    addr = str(arg.name)
                    addr = int(addr.split(":")[1], 16)
                    instr.args[i] = ExprInt(addr, args_size[i])

        dis_str = str(instr)
        dis_len = instr.l
    except Exception:
        # Best effort: report the failure in the output string. Only
        # Exception is caught, so KeyboardInterrupt / SystemExit are not
        # silently swallowed.
        dis_str = "/!\\ Can't disassemble using miasm /!\\"
        dis_len = 2  # GV: seems fishy!

    # Remaining bytes
    buf_hex = opcode[0:dis_len].encode("hex")

    # Check buffer sizes
    if len(dis_str)-1 > 256:
        dis_str = "/!\\ Disassembled instruction is too long /!\\"
    if len(buf_hex)-1 > 256:
        buf_hex = buf_hex[:255]

    # Fill the RAsmOp structure
    rasmop.size = dis_len
    rasmop.buf_asm = dis_str
    rasmop.buf_hex = buf_hex
def simp_cst_propagation(e_s, expr):
    """This pass includes:
     - Constant folding
     - Common logical identities
     - Common binary identities

    @e_s: expression simplifier; it is called on a few sub-expressions
          (for instance e_s(X - Y) or e_s(X.msb()))
    @expr: operation expression to simplify (its .op, .args, .size and
           .mask are read)

    Return either one of the (possibly rewritten) operands, a new
    expression implementing the matched identity, or an ExprOp rebuilt from
    the remaining arguments. The rules are order dependent: earlier rules
    may modify @args / @op_name for the later ones.
    """

    # Work on a mutable copy of the arguments
    args = list(expr.args)
    op_name = expr.op
    # Integer folding
    # int OP int => int
    # TODO: <<< >>> << >> are architecture dependent
    if op_name in op_propag_cst:
        # Fold the two trailing integers until at most one is left
        while (len(args) >= 2 and
               args[-1].is_int() and
               args[-2].is_int()):
            int2 = args.pop()
            int1 = args.pop()
            if op_name == '+':
                out = int1.arg + int2.arg
            elif op_name == '*':
                out = int1.arg * int2.arg
            elif op_name == '**':
                out = int1.arg ** int2.arg
            elif op_name == '^':
                out = int1.arg ^ int2.arg
            elif op_name == '&':
                out = int1.arg & int2.arg
            elif op_name == '|':
                out = int1.arg | int2.arg
            elif op_name == '>>':
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg >> int2.arg
            elif op_name == '<<':
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg << int2.arg
            elif op_name == 'a>>':
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                if tmp2 > int1.size:
                    # Shift larger than the size: only the sign remains
                    is_signed = int(int1) & (1 << (int1.size - 1))
                    if is_signed:
                        out = -1
                    else:
                        out = 0
                else:
                    out = mod_size2uint[int1.arg.size](tmp1 >> tmp2)
            elif op_name == '>>>':
                shifter = int2.arg % int2.size
                out = (int1.arg >> shifter) | (int1.arg << (int2.size - shifter))
            elif op_name == '<<<':
                shifter = int2.arg % int2.size
                out = (int1.arg << shifter) | (int1.arg >> (int2.size - shifter))
            elif op_name == '/':
                out = int1.arg / int2.arg
            elif op_name == '%':
                out = int1.arg % int2.arg
            elif op_name == 'idiv':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)
            elif op_name == 'imod':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'umod':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'udiv':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)

            args.append(ExprInt(out, int1.size))

    # cnttrailzeros(int) => int
    if op_name == "cnttrailzeros" and args[0].is_int():
        i = 0
        while args[0].arg & (1 << i) == 0 and i < args[0].size:
            i += 1
        return ExprInt(i, args[0].size)

    # cntleadzeros(int) => int
    if op_name == "cntleadzeros" and args[0].is_int():
        if args[0].arg == 0:
            return ExprInt(args[0].size, args[0].size)
        i = args[0].size - 1
        while args[0].arg & (1 << i) == 0:
            i -= 1
        return ExprInt(expr.size - (i + 1), args[0].size)

    # -(-(A)) => A
    if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and
            len(args[0].args) == 1):
        return args[0].args[0]

    # -(int) => -int
    if op_name == '-' and len(args) == 1 and args[0].is_int():
        return ExprInt(-int(args[0]), expr.size)
    # A op 0 => A
    if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
        if args[-1].is_int(0):
            args.pop()
    # A - 0 => A
    if op_name == '-' and len(args) > 1 and args[-1].is_int(0):
        assert len(args) == 2  # Op '-' with more than 2 args: SanityCheckError
        return args[0]

    # A * 1 => A
    if op_name == "*" and len(args) > 1 and args[-1].is_int(1):
        args.pop()

    # For canonical form
    # A * -1 => - A
    if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask:
        args.pop()
        args[-1] = - args[-1]

    # op A => A
    if op_name in ['+', '*', '^', '&', '|', '>>', '<<',
                   'a>>', '<<<', '>>>', 'idiv', 'imod', 'umod', 'udiv'] and len(args) == 1:
        return args[0]

    # A-B => A + (-B)
    if op_name == '-' and len(args) > 1:
        if len(args) > 2:
            raise ValueError(
                'sanity check fail on expr -: should have one or 2 args ' +
                '%r %s' % (expr, expr))
        return ExprOp('+', args[0], -args[1])

    # A op 0 => 0
    if op_name in ['&', "*"] and args[-1].is_int(0):
        return ExprInt(0, expr.size)

    # - (A + B +...) => -A + -B + -C
    if op_name == '-' and len(args) == 1 and args[0].is_op('+'):
        args = [-a for a in args[0].args]
        return ExprOp('+', *args)

    # -(a?int1:int2) => (a?-int1:-int2)
    if (op_name == '-' and len(args) == 1 and
            args[0].is_cond() and
            args[0].src1.is_int() and args[0].src2.is_int()):
        int1 = args[0].src1
        int2 = args[0].src2
        int1 = ExprInt(-int1.arg, int1.size)
        int2 = ExprInt(-int2.arg, int2.size)
        return ExprCond(args[0].cond, int1, int2)

    # Pairwise cancellation / deduplication of arguments. @j is not
    # advanced after a deletion, as the next candidate took its place.
    i = 0
    while i < len(args) - 1:
        j = i + 1
        while j < len(args):
            # A ^ A => 0
            if op_name == '^' and args[i] == args[j]:
                args[i] = ExprInt(0, args[i].size)
                del args[j]
                continue
            # A + (- A) => 0
            if op_name == '+' and args[j].is_op("-"):
                if len(args[j].args) == 1 and args[i] == args[j].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # (- A) + A => 0
            if op_name == '+' and args[i].is_op("-"):
                if len(args[i].args) == 1 and args[j] == args[i].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # A | A => A
            if op_name == '|' and args[i] == args[j]:
                del args[j]
                continue
            # A & A => A
            if op_name == '&' and args[i] == args[j]:
                del args[j]
                continue
            j += 1
        i += 1

    if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1:
        return args[0]

    # A <<< A.size => A
    if (op_name in ['<<<', '>>>'] and
            args[1].is_int() and
            args[1].arg == args[0].size):
        return args[0]

    # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
    if (op_name in ['<<<', '>>>'] and
            args[0].is_op() and
            args[0].op in ['<<<', '>>>']):
        A = args[0].args[0]
        X = args[0].args[1]
        Y = args[1]
        if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size):
            # Opposite rotations by the same amount cancel each other
            return args[0].args[0]
        elif X.is_int() and Y.is_int():
            new_X = int(X) % expr.size
            new_Y = int(Y) % expr.size
            if op_name == args[0].op:
                rot = (new_X + new_Y) % expr.size
                op = op_name
            else:
                rot = new_Y - new_X
                op = op_name
                if rot < 0:
                    # Negative rotation: rotate the other way instead
                    rot = - rot
                    op = {">>>": "<<<", "<<<": ">>>"}[op_name]
            args = [A, ExprInt(rot, expr.size)]
            op_name = op

        else:
            # Do not consider this case, too tricky (overflow on addition /
            # subtraction)
            pass

    # A >> X >> Y  =>  A >> (X+Y) if X + Y does not overflow
    # To be sure, only consider the simplification when X.msb and Y.msb are 0
    if (op_name in ['<<', '>>'] and
            args[0].is_op(op_name)):
        X = args[0].args[1]
        Y = args[1]
        if (e_s(X.msb()) == ExprInt(0, 1) and
                e_s(Y.msb()) == ExprInt(0, 1)):
            args = [args[0].args[0], X + Y]

    # ((var >> int1) << int1) => var & mask
    # ((var << int1) >> int1) => var & mask
    # Only meaningful when the inner and the outer shifts go in opposite
    # directions, hence the comparison of the two operator names.
    if (op_name in ['<<', '>>'] and
            args[0].is_op() and
            args[0].op in ['<<', '>>'] and
            op_name != args[0].op):
        var = args[0].args[0]
        int1 = args[0].args[1]
        int2 = args[1]
        if int1 == int2 and int1.is_int() and int(int1) < expr.size:
            if op_name == '>>':
                # Upper int1 bits are cleared
                mask = ExprInt((1 << (expr.size - int(int1))) - 1, expr.size)
            else:
                # Lower int1 bits are cleared
                mask = ExprInt(
                    ((1 << int(int1)) - 1) ^ ((1 << expr.size) - 1),
                    expr.size
                )
            ret = var & mask
            return ret

    # A & A.mask => A
    if op_name == "&" and args[-1] == expr.mask:
        return ExprOp('&', *args[:-1])

    # A | A.mask => A.mask
    if op_name == "|" and args[-1] == expr.mask:
        return args[-1]

    # ! (!X + int) => X - int
    # TODO

    # ((A & mask) >> shift) with mask < 2**shift => 0
    if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
        if (args[0].args[1].is_int() and
                2 ** args[1].arg > args[0].args[1].arg):
            return ExprInt(0, args[0].size)

    # parity(int) => int
    if op_name == 'parity' and args[0].is_int():
        return ExprInt(parity(int(args[0])), 1)

    # (-a) * b * (-c) * (-d) => -(a * b * c * d)
    # The negations are stripped from the factors; an odd count leaves a
    # single negation in front of the product.
    if op_name == "*" and len(args) > 1:
        new_args = []
        counter = 0
        for arg in args:
            if arg.is_op('-') and len(arg.args) == 1:
                new_args.append(arg.args[0])
                counter += 1
            else:
                new_args.append(arg)
        if counter % 2:
            return -ExprOp(op_name, *new_args)
        args = new_args

    # -(a * b * int) => a * b * (-int)
    if op_name == "-" and args[0].is_op('*') and args[0].args[-1].is_int():
        args = args[0].args
        return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)]))

    # A << int with A ExprCompose => move index
    if (op_name == "<<" and args[0].is_compose() and
            args[1].is_int() and int(args[1]) != 0):
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # Shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index+shift, index+shift+arg.size))
        # Filter out expressions shifted beyond the final size
        filter_args = []
        min_index = final_size
        for tmp, start, stop in new_args:
            if start >= final_size:
                continue
            if stop > final_size:
                tmp = tmp[:tmp.size - (stop - final_size)]
            filter_args.append(tmp)
            min_index = min(start, min_index)
        # Create entry 0
        assert min_index != 0
        tmp = ExprInt(0, min_index)
        args = [tmp] + filter_args
        return ExprCompose(*args)

    # A >> int with A ExprCompose => move index
    if op_name == ">>" and args[0].is_compose() and args[1].is_int():
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # Shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index-shift, index+arg.size-shift))
        # Filter out expressions shifted below bit 0
        filter_args = []
        max_index = 0
        for tmp, start, stop in new_args:
            if stop <= 0:
                continue
            if start < 0:
                tmp = tmp[-start:]
            filter_args.append(tmp)
            max_index = max(stop, max_index)
        # Pad the upper part with zeros
        tmp = ExprInt(0, final_size - max_index)
        args = filter_args + [tmp]
        return ExprCompose(*args)

    # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b)
    if op_name in ['|', '&', '^'] and all([arg.is_compose() for arg in args]):
        bounds = set()
        for arg in args:
            bound = tuple([tmp.size for tmp in arg.args])
            bounds.add(bound)
        if len(bounds) == 1:
            new_args = [[tmp] for tmp in args[0].args]
            for sub_arg in args[1:]:
                for i, tmp in enumerate(sub_arg.args):
                    new_args[i].append(tmp)
            args = []
            for i, arg in enumerate(new_args):
                args.append(ExprOp(op_name, *arg))
            return ExprCompose(*args)

    return ExprOp(op_name, *args)