def csetm(ir, instr, arg1, arg2):
    """CSETM semantic: arg1 = -1 (sized like arg1) if the condition named by
    arg2 holds, else 0.

    Returns the (assignments, extra IR blocks) pair; no extra block is built.
    """
    minus_one = m2_expr.ExprInt_from(arg1, -1)
    zero = m2_expr.ExprInt_from(arg1, 0)
    selected = m2_expr.ExprCond(cond2expr[arg2.name], minus_one, zero)
    return [m2_expr.ExprAff(arg1, selected)], []
def get_mem_overlapping(self, e, eval_cache=None):
    """Return the known memory cells overlapping the memory expression e.

    @e: ExprMem whose address range is checked
    @eval_cache: optional dict passed down to eval_expr

    Returns a list of (offset, mem) pairs. mem is a stored memory expression
    taken from self.symbols.symbols_mem that overlaps e. offset is the byte
    offset of mem's address relative to e's address, and is negative when
    mem starts before e.

    Raises ValueError if e is not an ExprMem, or if the distance between e's
    address and a candidate cell address does not simplify to an ExprInt.
    """
    if eval_cache is None:
        eval_cache = {}
    if not isinstance(e, m2_expr.ExprMem):
        raise ValueError('mem overlap bad arg')
    # Stored cells are assumed to be at most 64 bits (8 bytes) wide. A cell
    # starting up to 7 bytes before e may still reach into e, so every
    # address from e.arg - 7 up to e's last byte is a candidate.
    base_ptr = self.expr_simp(e.arg)
    to_test = []
    for i in xrange(-7, e.size / 8):
        ex = self.expr_simp(
            self.eval_expr(base_ptr + m2_expr.ExprInt_from(e.arg, i),
                           eval_cache))
        to_test.append((i, ex))

    ov = []
    for i, x in to_test:
        if x not in self.symbols.symbols_mem:
            continue
        ex = self.expr_simp(self.eval_expr(e.arg - x, eval_cache))
        if not isinstance(ex, m2_expr.ExprInt):
            raise ValueError('ex is not ExprInt')
        # Byte distance from the cell's start to e's start. The cell does
        # not overlap e when it ends at or before that point.
        ptr_diff = int32(ex.arg)
        if ptr_diff >= self.symbols.symbols_mem[x][1].size / 8:
            continue
        ov.append((i, self.symbols.symbols_mem[x][0]))
    return ov
def csinc(ir, instr, arg1, arg2, arg3, arg4):
    """CSINC semantic: arg1 = arg2 if the condition named by arg4 holds,
    otherwise arg3 + 1.

    Returns the (assignments, extra IR blocks) pair; no extra block is built.
    """
    incremented = arg3 + m2_expr.ExprInt_from(arg3, 1)
    chosen = m2_expr.ExprCond(cond2expr[arg4.name], arg2, incremented)
    return [m2_expr.ExprAff(arg1, chosen)], []
def gen_pc_update(self, c, l):
    """Append to c.irs an assignment block setting self.pc to l.offset,
    then record l in c.lines.

    @c: object holding the `irs` and `lines` lists being built
    @l: object with an `offset` attribute (presumably the current
        instruction line - confirm against callers)
    """
    pc_value = m2_expr.ExprInt_from(self.pc, l.offset)
    pc_block = AssignBlock([m2_expr.ExprAff(self.pc, pc_value)])
    c.irs.append(pc_block)
    c.lines.append(l)
def substract_mems(self, a, b):
    """Return the parts of memory a that are not covered by memory b.

    @a: ExprMem whose current value is self.symbols[a]
    @b: ExprMem overlapping a (presumably a write clobbering part of a -
        confirm against callers)

    Returns None when the distance b.arg - a.arg does not simplify to an
    ExprInt. Otherwise returns a list of (ExprMem, value) pairs, one per
    leftover piece of a, where value is the matching bit slice of
    self.symbols[a]. The list is empty when b covers a entirely.
    """
    ex = b.arg - a.arg
    ex = self.expr_simp(self.eval_expr(ex, {}))
    if not isinstance(ex, m2_expr.ExprInt):
        return None
    # Signed byte distance from a's address to b's address.
    ptr_diff = int(int32(ex.arg))
    out = []
    if ptr_diff < 0:
        # b starts before a; only the tail of a (XXX) may survive:
        #    [a    ]
        # [b    ]XXX
        sub_size = b.size + ptr_diff * 8
        if sub_size < a.size:
            ex = m2_expr.ExprOp('+', a.arg,
                                m2_expr.ExprInt_from(a.arg, sub_size / 8))
            rest_ptr = self.expr_simp(self.eval_expr(ex, {}))
            rest_size = a.size - sub_size
            val = self.symbols[a][sub_size:a.size]
            out.append((m2_expr.ExprMem(rest_ptr, rest_size), val))
    else:
        # b starts at or after a; a head (XXXX) and a tail (YY) may survive:
        # [a          ]
        # XXXX[b    ]YY
        if ptr_diff > 0:
            val = self.symbols[a][0:ptr_diff * 8]
            out.append((m2_expr.ExprMem(a.arg, ptr_diff * 8), val))
        if ptr_diff * 8 + b.size < a.size:
            ex = m2_expr.ExprOp('+', b.arg,
                                m2_expr.ExprInt_from(b.arg, b.size / 8))
            ex = self.expr_simp(self.eval_expr(ex, {}))
            val = self.symbols[a][ptr_diff * 8 + b.size:a.size]
            out.append((m2_expr.ExprMem(ex, val.size), val))
    return out
def eval_ExprId(self, e, eval_cache=None):
    """Evaluate the identifier e against the current symbol state.

    A label identifier whose offset is known becomes an ExprInt of that
    offset. Otherwise the value recorded in self.symbols is returned, or e
    itself when it has no recorded value. eval_cache is accepted for
    interface compatibility and is not used here.
    """
    if eval_cache is None:
        eval_cache = {}
    name = e.name
    if isinstance(name, asmbloc.asm_label) and name.offset is not None:
        return m2_expr.ExprInt_from(e, name.offset)
    return self.symbols[e] if e in self.symbols else e
def ldp(ir, instr, arg1, arg2, arg3):
    """LDP semantic: load arg1 from the address of arg3, and arg2 from the
    address right after it (offset by arg1's size in bytes).

    If get_mem_access reports an address-update assignment, it is appended
    after the two loads. Returns (assignments, extra IR blocks).
    """
    addr, updt = get_mem_access(arg3)
    next_addr = addr + m2_expr.ExprInt_from(addr, arg1.size / 8)
    e = [
        m2_expr.ExprAff(arg1, m2_expr.ExprMem(addr, arg1.size)),
        m2_expr.ExprAff(arg2, m2_expr.ExprMem(next_addr, arg2.size)),
    ]
    if updt:
        e.append(updt)
    return e, []
def ubfm(ir, instr, arg1, arg2, arg3, arg4):
    """UBFM semantic (unsigned bitfield move), with immr = arg3, imms = arg4.

    - imms >= immr: bits [immr, imms] of arg2, zero-extended to arg1's size.
    - imms < immr: the low (imms + 1) bits of arg2, zero-extended to arg1's
      size and shifted left by (arg2.size - immr).
    Returns (assignments, extra IR blocks).
    """
    immr = int(arg3.arg)
    field_end = int(arg4.arg) + 1
    if field_end <= immr:
        amount = m2_expr.ExprInt_from(arg2, arg2.size - immr)
        value = arg2[:field_end].zeroExtend(arg1.size) << amount
    else:
        value = arg2[immr:field_end].zeroExtend(arg1.size)
    return [m2_expr.ExprAff(arg1, value)], []
def bfm(ir, instr, arg1, arg2, arg3, arg4):
    """BFM semantic (bitfield move), with immr = arg3, imms = arg4.

    - imms >= immr: bits [immr, imms] of arg2 are written to the low
      (imms - immr + 1) bits of arg1.
    - imms < immr: the low (imms + 1) bits of arg2 are written into arg1
      starting at bit (arg2.size - immr).
    Only a slice of arg1 is assigned; its other bits get no assignment.
    Returns (assignments, extra IR blocks).
    """
    e = []
    rim, sim = int(arg3.arg), int(arg4.arg) + 1
    if sim > rim:
        res = arg2[rim:sim]
        e.append(m2_expr.ExprAff(arg1[:sim - rim], res))
    else:
        # Destination bit position of the inserted field.
        shift_i = arg2.size - rim
        res = arg2[:sim]
        e.append(m2_expr.ExprAff(arg1[shift_i:shift_i + sim], res))
    return e, []
def apply_expr_on_state_visit_cache(self, expr, state, cache, level=0): """ Deep First evaluate nodes: 1. evaluate node's sons 2. simplify """ #print '\t'*level, "Eval:", expr if expr in cache: ret = cache[expr] #print "In cache!", ret elif isinstance(expr, m2_expr.ExprInt): return expr elif isinstance(expr, m2_expr.ExprId): if isinstance(expr.name, asmbloc.asm_label) and expr.name.offset is not None: ret = m2_expr.ExprInt_from(expr, expr.name.offset) else: ret = state.get(expr, expr) elif isinstance(expr, m2_expr.ExprMem): ptr = self.apply_expr_on_state_visit_cache(expr.arg, state, cache, level+1) ret = m2_expr.ExprMem(ptr, expr.size) ret = self.get_mem_state(ret) assert expr.size == ret.size elif isinstance(expr, m2_expr.ExprCond): cond = self.apply_expr_on_state_visit_cache(expr.cond, state, cache, level+1) src1 = self.apply_expr_on_state_visit_cache(expr.src1, state, cache, level+1) src2 = self.apply_expr_on_state_visit_cache(expr.src2, state, cache, level+1) ret = m2_expr.ExprCond(cond, src1, src2) elif isinstance(expr, m2_expr.ExprSlice): arg = self.apply_expr_on_state_visit_cache(expr.arg, state, cache, level+1) ret = m2_expr.ExprSlice(arg, expr.start, expr.stop) elif isinstance(expr, m2_expr.ExprOp): args = [] for oarg in expr.args: arg = self.apply_expr_on_state_visit_cache(oarg, state, cache, level+1) assert oarg.size == arg.size args.append(arg) ret = m2_expr.ExprOp(expr.op, *args) elif isinstance(expr, m2_expr.ExprCompose): args = [] for arg in expr.args: args.append(self.apply_expr_on_state_visit_cache(arg, state, cache, level+1)) ret = m2_expr.ExprCompose(*args) else: raise TypeError("Unknown expr type") #print '\t'*level, "Result", ret ret = self.expr_simp(ret) #print '\t'*level, "Result simpl", ret assert expr.size == ret.size cache[expr] = ret return ret
def extend_arg(dst, arg):
    """Expand an extended/shifted register operand to dst's size.

    Operands that are not ExprOp are returned unchanged. For an ExprOp with
    two arguments (reg, shift), reg is sign-extended when the op is 'SXTW'
    and zero-extended otherwise. It is then shifted left by shift, with the
    shift amount ANDed with (dst.size - 1).

    NOTE(review): every op other than 'SXTW' takes the zero-extend + left
    shift path - confirm no other shifter kinds reach this function.
    """
    if isinstance(arg, m2_expr.ExprOp):
        op = arg.op
        reg, shift = arg.args
        extend = reg.signExtend if op == 'SXTW' else reg.zeroExtend
        base = extend(dst.size)
        mask = m2_expr.ExprInt_from(dst, dst.size - 1)
        return base << (shift.zeroExtend(dst.size) & mask)
    return arg
def simp_add_mul(expr_simp, expr):
    """Naive simplification pass: a + a + a == a * 3.

    Only a '+' operation with exactly three identical arguments is rewritten
    into a '*' by the constant 3; any other expression is returned as is.
    There is no isinstance(expr, ExprOp) check because simplification passes
    are registered per expression type.
    """
    args = expr.args
    is_triple_sum = (expr.op == "+" and
                     len(args) == 3 and
                     args.count(args[0]) == len(args))
    if not is_triple_sum:
        return expr
    term = args[0]
    return m2_expr.ExprOp("*", term, m2_expr.ExprInt_from(term, 3))
def merge_sliceto_slice(args):
    """Fuse adjacent ExprCompose arguments where possible.

    @args: list of (expr, start, stop) triples describing the pieces of an
        ExprCompose (start/stop are bit positions in the composed value).

    Two kinds of fusion are performed:
    - adjacent ExprInt pieces are merged into a single ExprInt piece;
    - adjacent ExprSlice pieces taken from the same source expression, whose
      slice bounds are also contiguous, are merged into a single slice.
    Any other piece is kept unchanged.

    Returns the new list of (expr, start, stop) entries, sorted by start.
    Merged ExprInt entries are lists; the other entries are tuples.
    """
    sources = {}
    non_slice = {}
    sources_int = {}
    for a in args:
        if isinstance(a[0], m2_expr.ExprInt):
            # Copy the ExprInt value (rebuilt through its own class) so the
            # caller's argument is never shared with the merged result.
            sources_int[a[1]] = (m2_expr.ExprInt_fromsize(
                a[2] - a[1],
                a[0].arg.__class__(a[0].arg)), a[1], a[2])
        elif isinstance(a[0], m2_expr.ExprSlice):
            # Group slices by the expression they are taken from.
            sources.setdefault(a[0].arg, []).append(a)
        else:
            non_slice[a[1]] = a

    # First merge the constant pieces.
    final_sources = []
    sorted_s = []
    for x in sources_int.values():
        x = list(x)
        # Mask each constant to its piece width.
        v = x[0].arg & ((1 << (x[2] - x[1])) - 1)
        x[0] = m2_expr.ExprInt_from(x[0], v)
        x = tuple(x)
        sorted_s.append((x[1], x))
    sorted_s.sort()
    # Pieces are popped from the highest start downwards. A piece whose stop
    # equals the current start lies just below it and is folded in.
    while sorted_s:
        start, v = sorted_s.pop()
        out = [m2_expr.ExprInt(v[0].arg), v[1], v[2]]
        size = v[2] - v[1]
        while sorted_s:
            if sorted_s[-1][1][2] != start:
                break
            s_start, s_stop = sorted_s[-1][1][1], sorted_s[-1][1][2]
            size += s_stop - s_start
            # Upper bits come from the already merged piece, shifted above
            # the lower piece's width.
            a = m2_expr.mod_size2uint[size](
                (int(out[0].arg) << (out[1] - s_start)) +
                int(sorted_s[-1][1][0].arg))
            out[0] = m2_expr.ExprInt(a)
            sorted_s.pop()
            out[1] = s_start
            # The merged piece now begins at s_start. Track it so the next
            # lower piece can be folded in too, and so the sort key below
            # matches the piece's real start.
            start = s_start
        out[0] = m2_expr.ExprInt_fromsize(size, out[0].arg)
        final_sources.append((start, out))

    final_sources_int = final_sources
    # Then merge contiguous slices of the same source expression: they must
    # be adjacent both in the composition and in the source.
    simp_sources = []
    for slice_args in sources.values():
        final_sources = []
        sorted_s = []
        for x in slice_args:
            sorted_s.append((x[1], x))
        sorted_s.sort()
        while sorted_s:
            start, v = sorted_s.pop()
            ee = v[0].arg[v[0].start:v[0].stop]
            out = ee, v[1], v[2]
            while sorted_s:
                if sorted_s[-1][1][2] != start:
                    break
                if sorted_s[-1][1][0].stop != out[0].start:
                    break
                start = sorted_s[-1][1][1]
                o_e, _, o_stop = out
                o1, o2 = sorted_s[-1][1][0].start, o_e.stop
                o_e = o_e.arg[o1:o2]
                out = o_e, start, o_stop
                sorted_s.pop()
            out = out[0], start, out[2]
            final_sources.append((start, out))
        simp_sources += final_sources

    simp_sources += final_sources_int
    for i, v in non_slice.items():
        simp_sources.append((i, v))
    simp_sources.sort()
    simp_sources = [x[1] for x in simp_sources]
    return simp_sources
def lsl(arg1, arg2, arg3):
    # LSL semantic: arg1 = arg2 shifted left by arg3. The shift amount is
    # ANDed with (arg3.size - 1).
    # NOTE(review): as plain Python this assignment is a dead local store;
    # the body is presumably parsed by a semantic-builder decorator (not
    # visible here) that turns assignments into IR - keep its exact form.
    arg1 = arg2 << (arg3 & m2_expr.ExprInt_from(arg3, arg3.size - 1))
def eval_ExprMem(self, e, eval_cache=None):
    """Evaluate the memory expression e against the current symbolic state.

    @e: ExprMem to evaluate
    @eval_cache: optional dict passed down to the nested evaluations

    The address is evaluated first, then the value is looked up:
    - if the (evaluated) ExprMem is a key of self.symbols, its value is
      returned;
    - if no memory cell is stored at exactly the evaluated address, the
      value is rebuilt from the overlapping known cells, with uncovered
      bits read as fresh ExprMem. If nothing can be rebuilt, func_read is
      used for a constant address; otherwise the ExprMem itself is returned,
      marked is_term;
    - if the cell stored at that address is smaller than e, consecutive
      cells are concatenated;
    - otherwise the low e.size bits of the stored value are returned.
    """
    if eval_cache is None:
        eval_cache = {}
    a_val = self.expr_simp(self.eval_expr(e.arg, eval_cache))
    if a_val != e.arg:
        a = self.expr_simp(m2_expr.ExprMem(a_val, size=e.size))
    else:
        a = e
    if a in self.symbols:
        return self.symbols[a]
    tmp = None
    # Is a memory cell stored at exactly this address?
    if a_val in self.symbols.symbols_mem:
        tmp = self.symbols.symbols_mem[a_val][0]
    if tmp is None:
        v = self.find_mem_by_addr(a_val)
        if not v:
            # Rebuild the value from the known cells overlapping a; each
            # entry of ov is (byte offset of the cell relative to a, cell).
            out = []
            ov = self.get_mem_overlapping(a, eval_cache)
            off_base = 0
            ov.sort()
            for off, x in ov:
                if off >= 0:
                    # The cell starts off bytes into a: place its bits at
                    # that bit offset, so a gap between two known cells is
                    # left for rest_slice below instead of shifting the
                    # following cells down.
                    off_base = off * 8
                    m = min(a.size - off * 8, x.size)
                    ee = m2_expr.ExprSlice(self.symbols[x], 0, m)
                    ee = self.expr_simp(ee)
                    out.append((ee, off_base, off_base + m))
                else:
                    # The cell starts -off bytes before a (this can only be
                    # the first, lowest cell): keep its bits from a's start.
                    m = min(a.size - off * 8, x.size)
                    ee = m2_expr.ExprSlice(self.symbols[x], -off * 8, m)
                    ff = self.expr_simp(ee)
                    new_off_base = off_base + m + off * 8
                    out.append((ff, off_base, new_off_base))
                    off_base = new_off_base
            if out:
                # Bits not covered by any known cell are read as fresh,
                # terminal memory accesses.
                missing_slice = self.rest_slice(out, 0, a.size)
                for sa, sb in missing_slice:
                    ptr = self.expr_simp(
                        a_val + m2_expr.ExprInt_from(a_val, sa / 8))
                    mm = m2_expr.ExprMem(ptr, size=sb - sa)
                    mm.is_term = True
                    mm.is_simp = True
                    out.append((mm, sa, sb))
                out.sort(key=lambda x: x[1])
                ee = m2_expr.ExprSlice(m2_expr.ExprCompose(out), 0, a.size)
                ee = self.expr_simp(ee)
                return ee
        if self.func_read and isinstance(a.arg, m2_expr.ExprInt):
            return self.func_read(a)
        else:
            # XXX hack test
            a.is_term = True
            return a

    # Bigger lookup: the stored cell is smaller than a, so walk consecutive
    # cells until a.size bits are collected.
    if a.size > tmp.size:
        rest = a.size
        ptr = a_val
        out = []
        ptr_index = 0
        while rest:
            v = self.find_mem_by_addr(ptr)
            if v is None:
                # Unknown byte: read it as a fresh 8-bit memory access.
                val = m2_expr.ExprMem(ptr, 8)
                v = val
                diff_size = 8
            elif rest >= v.size:
                val = self.symbols[v]
                diff_size = v.size
            else:
                diff_size = rest
                val = self.symbols[v][0:diff_size]
            val = (val, ptr_index, ptr_index + diff_size)
            out.append(val)
            ptr_index += diff_size
            rest -= diff_size
            ptr = self.expr_simp(
                self.eval_expr(
                    m2_expr.ExprOp('+', ptr,
                                   m2_expr.ExprInt_from(ptr, v.size / 8)),
                    eval_cache))
        e = self.expr_simp(m2_expr.ExprCompose(out))
        return e
    # Part lookup: the stored cell is at least as large as a.
    tmp = self.expr_simp(m2_expr.ExprSlice(self.symbols[tmp], 0, a.size))
    return tmp
def lsr(arg1, arg2, arg3):
    # LSR semantic: arg1 = arg2 shifted right (miasm '>>') by arg3. The shift
    # amount is ANDed with (arg3.size - 1).
    # NOTE(review): as plain Python this assignment is a dead local store;
    # the body is presumably parsed by a semantic-builder decorator (not
    # visible here) that turns assignments into IR - keep its exact form.
    arg1 = arg2 >> (arg3 & m2_expr.ExprInt_from(arg3, arg3.size - 1))
def simp_add_mul(expr_simp, expr):
    "Naive Simplification: a + a + a == a * 3"
    # Match the expected form.
    # isinstance(expr, m2_expr.ExprOp) is not needed: simplifications are
    # attached to expression types.
    if expr.op == "+" and \
            len(expr.args) == 3 and \
            expr.args.count(expr.args[0]) == len(expr.args):
        # Effective simplification
        return m2_expr.ExprOp("*", expr.args[0],
                              m2_expr.ExprInt_from(expr.args[0], 3))
    else:
        # Do not simplify
        return expr


# Demonstration: register the pass above and check that a + a + a becomes
# a * 3 once it is enabled.
a = m2_expr.ExprId('a')
base_expr = a + a + a
print("Without adding the simplification:")
print("\t%s = %s" % (base_expr, expr_simp(base_expr)))

# Enable pass
expr_simp.enable_passes({m2_expr.ExprOp: [simp_add_mul]})

print("After adding the simplification:")
print("\t%s = %s" % (base_expr, expr_simp(base_expr)))

# Fail loudly if the pass did not apply.
assert expr_simp(base_expr) == m2_expr.ExprOp("*", a, m2_expr.ExprInt_from(a, 3))
def get_asm_offset(self, x):
    """Return self.offset as an ExprInt sized like the expression x."""
    offset = self.offset
    return m2_expr.ExprInt_from(x, offset)
def set_pc(ir_arch, src):
    """Build the assignment of src to ir_arch.jit_pc.

    A src that is not an Expr is first converted to an ExprInt sized like
    jit_pc. The value is zero-extended to jit_pc's size.
    """
    pc = ir_arch.jit_pc
    value = src if isinstance(src, m2_expr.Expr) else m2_expr.ExprInt_from(pc, src)
    return m2_expr.ExprAff(pc, value.zeroExtend(pc.size))
def expr_calc(e):
    """Replace an ExprId by an ExprInt of its symbol's offset.

    The symbol is looked up by name in the module-level `symbols.s` mapping
    (defined outside this block). Any other expression is returned unchanged.
    """
    if not isinstance(e, m2_expr.ExprId):
        return e
    symbol = symbols.s[e.name]
    return m2_expr.ExprInt_from(e, symbol.offset)
def asr(arg1, arg2, arg3):
    # ASR semantic: arg1 = arg2 shifted right with the 'a>>' operator
    # (arithmetic shift, per the mnemonic) by arg3. The shift amount is
    # ANDed with (arg3.size - 1).
    # NOTE(review): as plain Python this assignment is a dead local store;
    # the body is presumably parsed by a semantic-builder decorator (not
    # visible here) that turns assignments into IR - keep its exact form.
    arg1 = m2_expr.ExprOp('a>>', arg2, (arg3 & m2_expr.ExprInt_from(arg3, arg3.size - 1)))
def get_asm_offset(self, expr):
    """Return self.offset as an ExprInt sized like the expression expr."""
    offset = self.offset
    return m2_expr.ExprInt_from(expr, offset)
def tbnz(arg1, arg2, arg3):
    # TBNZ semantic: branch to arg3 when bit number arg2 of arg1 is set.
    # NOTE(review): `ir`, `instr` and `PC` are not defined in this block, and
    # the assignments are dead as plain Python. The body is presumably
    # parsed by a semantic-builder decorator (not visible here) that turns
    # assignments into IR - keep its exact form.
    bitmask = m2_expr.ExprInt_from(arg1, 1) << arg2
    # Bit set: go to arg3; otherwise go to the label returned by
    # ir.get_next_label(instr) (presumably the fall-through instruction).
    dst = arg3 if arg1 & bitmask else m2_expr.ExprId(ir.get_next_label(instr), 64)
    PC = dst
    ir.IRDst = dst
def label2offset(e):
    """Replace a label identifier by an ExprInt of the label's offset.

    Only an ExprId whose name is an asm_label is converted; every other
    expression is returned unchanged.
    """
    is_label = (isinstance(e, m2_expr.ExprId) and
                isinstance(e.name, asmbloc.asm_label))
    if is_label:
        return m2_expr.ExprInt_from(e, e.name.offset)
    return e
# Generate IR for bloc in blocs: ir_arch.add_bloc(bloc) # Get settings settings = depGraphSettingsForm(ir_arch) settings.Execute() label, elements, line_nb = settings.label, settings.elements, settings.line_nb # Simplify affectations for irb in ir_arch.blocs.values(): fix_stack = irb.label.offset is not None and settings.unalias_stack for i, assignblk in enumerate(irb.irs): if fix_stack: stk_high = m2_expr.ExprInt_from(ir_arch.sp, GetSpd(irb.lines[i].offset)) fix_dct = {ir_arch.sp: mn.regs.regs_init[ir_arch.sp] + stk_high} for dst, src in assignblk.items(): del (assignblk[dst]) if fix_stack: src = src.replace_expr(fix_dct) if dst != ir_arch.sp: dst = dst.replace_expr(fix_dct) dst, src = expr_simp(dst), expr_simp(src) assignblk[dst] = src # Get dependency graphs dg = settings.depgraph graphs = dg.get(label, elements, line_nb, set([ir_arch.symbol_pool.getby_offset(func.startEA)]))