def get_mem_overlapping(self, e, eval_cache=None):
        """Find pool memory entries that overlap the memory access ``e``.

        Every candidate start address ``e.arg + i`` for byte offsets ``i`` in
        ``[-7, e.size // 8)`` is evaluated, simplified and looked up in
        ``self.pool.pool_mem``. A candidate found there is kept unless its
        byte distance to ``e`` (``e.arg - candidate``) is at least the size,
        in bytes, of ``self.pool.pool_mem[candidate][1]``.

        :param e: the ExprMem to test.
        :param eval_cache: evaluation cache forwarded to ``self.eval_expr``;
            a fresh dict is used when omitted.
        :return: list of ``(i, mem)`` tuples, ``i`` being the byte offset of
            ``mem``'s start relative to ``e``'s start.
        :raises ValueError: if ``e`` is not an ExprMem, or if the distance to
            a found entry does not simplify to an ExprInt.
        """
        if not isinstance(e, ExprMem):
            raise ValueError('mem overlap bad arg')
        # A shared mutable default would leak cached evaluations between
        # calls made against different pool states.
        if eval_cache is None:
            eval_cache = {}

        # Entries starting up to 7 bytes before e are considered, i.e. pool
        # entries are assumed to be at most 8 bytes (64 bits) wide.
        to_test = []
        for i in range(-7, e.size // 8):
            ex = expr_simp(
                self.eval_expr(e.arg + ExprInt(uint32(i)), eval_cache))
            to_test.append((i, ex))

        ov = []
        for i, x in to_test:
            if x not in self.pool.pool_mem:
                continue
            ex = expr_simp(self.eval_expr(e.arg - x, eval_cache))
            if not isinstance(ex, ExprInt):
                raise ValueError("%s should be ExprInt instead of %s" %
                                 (ex, ex.__class__.__name__))
            ptr_diff = int32(ex.arg)
            # Entry starts before e and is too short to reach it.
            if ptr_diff >= self.pool.pool_mem[x][1].get_size() / 8:
                continue
            ov.append((i, self.pool.pool_mem[x][0]))
        return ov
    def get_instr_mod(self, exprs):
        """Evaluate a list of ExprAff assignments and collect their effects.

        All evaluations in one call share a single cache. Each source is
        evaluated with ``self.eval_expr`` before its destination. A memory
        destination has its address evaluated and simplified. When
        ``self.func_write`` is set and that address is an ExprInt, the write
        is delegated to ``self.func_write(self, dst, src, pool_out)``
        instead of being stored directly.

        :param exprs: iterable of ExprAff.
        :return: dict mapping each destination (memory destinations rebuilt
            on their evaluated address) to its evaluated source.
        :raises TypeError: if an element is not an ExprAff.
        :raises ValueError: if a destination is ExprTop or of another
            unsupported kind.
        """
        modified = {}
        cache = {}
        for aff in exprs:
            if not isinstance(aff, ExprAff):
                raise TypeError('not affect', str(aff))
            value = self.eval_expr(aff.src, cache)
            target = aff.dst
            if isinstance(target, ExprMem):
                address = expr_simp(self.eval_expr(target.arg, cache))
                mem_dst = ExprMem(address, target.size)
                concrete = isinstance(mem_dst.arg, ExprInt)
                if self.func_write and concrete:
                    self.func_write(self, mem_dst, value, modified)
                else:
                    modified[mem_dst] = value
            elif isinstance(target, ExprId):
                modified[target] = value
            elif isinstance(target, ExprTop):
                raise ValueError("affect in ExprTop")
            else:
                raise ValueError("affected zarb", str(target))
        return modified
    def eval_ExprOp(self, e, eval_cache=None):
        """Evaluate an ExprOp, folding it to a constant when possible.

        Each argument is evaluated with ``self.eval_expr`` and simplified.
        If an argument evaluates to ExprTop, ExprTop is returned at once.
        If any argument is not an ExprInt, a new ExprOp over the evaluated
        arguments is returned. Otherwise ``self.deal_op[e.op]`` computes the
        result from the raw integer values.

        :param e: the ExprOp to evaluate.
        :param eval_cache: evaluation cache forwarded to ``self.eval_expr``;
            a fresh dict is used when omitted.
        :return: ExprTop, a symbolic ExprOp, the Expr returned by the
            operator handler, or an ExprInt wrapping its integer result.
        :raises ValueError: if the integer arguments have different types
            and ``e.op`` is not in ``self.op_size_no_check``.
        """
        # A shared mutable default would leak cached evaluations between
        # calls made against different states.
        if eval_cache is None:
            eval_cache = {}
        args = []
        for a in e.args:
            b = expr_simp(self.eval_expr(a, eval_cache))
            if isinstance(b, ExprTop):
                return ExprTop()
            args.append(b)

        # Anything not fully concrete stays symbolic.
        if not all(isinstance(a, ExprInt) for a in args):
            return ExprOp(e.op, *args)

        values = [a.arg for a in args]
        types_tab = [type(v) for v in values]
        # Operands must share one integer type unless the operator
        # explicitly tolerates mixed sizes.
        if (types_tab.count(types_tab[0]) != len(values)
                and e.op not in self.op_size_no_check):
            raise ValueError('invalid cast %r %r' % (types_tab, values))

        cast_int = types_tab[0]
        op_size = tab_int_size[cast_int]

        ret_value = self.deal_op[e.op](self, values, op_size, cast_int)
        if isinstance(ret_value, Expr):
            return ret_value
        return ExprInt(cast_int(ret_value))
    def substract_mems(self, a, b):
        """Return the parts of pool entry ``a`` that ``b`` does not cover.

        The byte distance ``b.arg - a.arg`` is evaluated. If it does not
        simplify to an ExprInt, None is returned. Otherwise the result is a
        list of ``(ExprMem, value)`` pairs, each value being the matching
        bit slice of ``self.pool[a]``. When ``b`` does not overlap ``a`` at
        all, ``a`` is returned whole.

        :param a: ExprMem key of ``self.pool`` to cut.
        :param b: ExprMem whose range is removed from ``a``.
        """
        ex = ExprOp('-', b.arg, a.arg)
        ex = expr_simp(self.eval_expr(ex, {}))
        if not isinstance(ex, ExprInt):
            return None
        ptr_diff = int(int32(ex.arg))
        # ptr_diff is in bytes; sizes and slice bounds are in bits, relative
        # to the start of a.
        b_start = ptr_diff * 8
        b_stop = b_start + b.size
        if b_stop <= 0 or b_start >= a.size:
            # Disjoint ranges: nothing of a is removed. The slicing below
            # would otherwise use negative or out-of-range bounds.
            return [(a, self.pool[a])]

        out = []
        if ptr_diff < 0:
            #    [a     ]
            #[b      ]XXX
            sub_size = b_stop  # bits of a covered by b
            if sub_size < a.size:
                ex = ExprOp('+', a.arg, ExprInt(uint32(sub_size // 8)))
                rest_ptr = expr_simp(self.eval_expr(ex, {}))
                rest_size = a.size - sub_size
                val = self.pool[a][sub_size:a.size]
                out = [(ExprMem(rest_ptr, rest_size), val)]
        else:
            #[a         ]
            #XXXX[b   ]YY
            #
            #[a     ]
            #XXXX[b     ]
            # part X: bytes of a before b
            if ptr_diff > 0:
                val = self.pool[a][0:b_start]
                out.append((ExprMem(a.arg, b_start), val))
            # part Y: bytes of a after b
            if b_stop < a.size:
                ex = ExprOp('+', b.arg, ExprInt(uint32(b.size // 8)))
                ex = expr_simp(self.eval_expr(ex, {}))
                val = self.pool[a][b_stop:a.size]
                out.append((ExprMem(ex, val.get_size()), val))
        return out
    def eval_ExprSlice(self, e, eval_cache=None):
        """Evaluate an ExprSlice.

        The sliced expression is evaluated and simplified first. ExprTop
        propagates. A slice covering a whole ExprMem collapses to that
        ExprMem. A slice of an ExprInt is simplified immediately. Anything
        else is rebuilt as an ExprSlice over the evaluated argument.

        :param e: the ExprSlice to evaluate.
        :param eval_cache: evaluation cache forwarded to ``self.eval_expr``;
            a fresh dict is used when omitted.
        """
        # A shared mutable default would leak cached evaluations between
        # calls made against different states.
        if eval_cache is None:
            eval_cache = {}
        arg = expr_simp(self.eval_expr(e.arg, eval_cache))
        if isinstance(arg, ExprTop):
            return ExprTop()
        if isinstance(arg, ExprMem) and e.start == 0 and e.stop == arg.size:
            return arg
        sliced = ExprSlice(arg, e.start, e.stop)
        if isinstance(arg, ExprInt):
            # Constant slices can be folded right away.
            return expr_simp(sliced)
        return sliced
 def is_mem_in_target(self, e, t):
     """Tell whether memory access ``e`` lies entirely inside ``t``.

     Returns None when the distance ``e.arg - t.arg`` does not simplify to
     an ExprInt. Otherwise returns True if ``e`` starts at or after ``t``
     and ends no later than ``t`` ends, and False if not.
     """
     distance = expr_simp(self.eval_expr(ExprOp('-', e.arg, t.arg), {}))
     if not isinstance(distance, ExprInt):
         return None
     start = int32(distance.arg)
     if start < 0:
         # e begins before t.
         return False
     return bool(start + e.size / 8 <= t.size / 8)
Example #7
0
def emul_lines(machine, lines):
    """Symbolically execute a sequence of instruction lines on ``machine``.

    For each line, the fall-through address (``line.offset`` plus
    ``line.l``) is built and the line is translated with
    ``get_instr_expr``. The result is executed with ``emul_full_expr``.
    Every value of ``machine.pool`` is then simplified.

    :return: the instruction pointer produced by the last line, or None
        when ``lines`` is empty.
    """
    next_eip = None
    for line in lines:
        next_eip = ExprInt(uint32(line.offset))
        next_eip.arg += uint32(line.l)
        instr_exprs = get_instr_expr(line, next_eip, [])
        next_eip, _ = emul_full_expr(instr_exprs, line, next_eip, None,
                                     machine)
        # Keep the machine state in simplified form after every step.
        for key in machine.pool:
            machine.pool[key] = expr_simp(machine.pool[key])
    return next_eip
    def eval_ExprMem(self, e, eval_cache=None):
        """Evaluate a memory read ``e`` against the current memory pool.

        Resolution order:

        1. If the address evaluates to ExprTop, an ExprMem over the
           unevaluated address is returned, flagged ``is_term``.
        2. An exact ``ExprMem(address, size)`` key of ``self.pool`` returns
           its value.
        3. If no pool entry starts at the address (``self.pool.pool_mem``)
           and ``self.find_mem_by_addr`` finds nothing, the value is built
           from overlapping entries (``self.get_mem_overlapping``). Uncovered
           bit ranges are filled with fresh ExprMem reads. When nothing
           overlaps, or ``find_mem_by_addr`` found something, the read goes
           to ``self.func_read`` for ExprInt addresses when it is set.
           Otherwise the ExprMem itself is returned, flagged ``is_term``.
        4. If an entry starts at the address, the result depends on size:
           - Same size: the entry's value is returned.
           - Wider read: consecutive entries are walked and composed.
           - Narrower read: the low bits of the entry's value are returned.

        :param e: the ExprMem to read.
        :param eval_cache: evaluation cache forwarded to ``self.eval_expr``;
            a fresh dict is used when omitted.
        """
        # A shared mutable default would leak cached evaluations between
        # calls made against different pool states.
        if eval_cache is None:
            eval_cache = {}
        a_val = expr_simp(self.eval_expr(e.arg, eval_cache))
        if isinstance(a_val, ExprTop):
            # XXX hack: unknown address, keep the read symbolic.
            ee = ExprMem(e.arg, e.size)
            ee.is_term = True
            return ee
        a = expr_simp(ExprMem(a_val, size=e.size))
        if a in self.pool:
            return self.pool[a]

        # Look for a pool entry starting exactly at this address.
        tmp = None
        if a_val in self.pool.pool_mem:
            tmp = self.pool.pool_mem[a_val][0]

        if tmp is None:
            v = self.find_mem_by_addr(a_val)
            if not v:
                out = []
                ov = self.get_mem_overlapping(a, eval_cache)
                ov.sort()
                ov.reverse()
                for off, x in ov:
                    if off >= 0:
                        # x starts `off` bytes inside a: its low bits go at
                        # that bit offset of a.
                        off_base = off * 8
                        m = min(a.get_size() - off_base, x.get_size())
                        ee = expr_simp(ExprSlice(self.pool[x], 0, m))
                    else:
                        # x starts before a: drop the bits of x below a and
                        # place the remainder at the bottom of a.
                        off_base = 0
                        m = min(a.get_size() - off * 8, x.get_size())
                        ee = expr_simp(ExprSlice(self.pool[x], -off * 8, m))
                    out.append((ee, off_base, off_base + ee.get_size()))
                if out:
                    # Fill the uncovered bit ranges with plain memory reads.
                    missing_slice = self.rest_slice(out, 0, a.get_size())
                    for sa, sb in missing_slice:
                        ptr = expr_simp(a_val + ExprInt32(sa // 8))
                        out.append((ExprMem(ptr, size=sb - sa), sa, sb))
                    out = sorted(out, key=lambda x: x[1])
                    ee = ExprSlice(ExprCompose(out), 0, a.get_size())
                    return expr_simp(ee)
            if self.func_read and isinstance(a.arg, ExprInt):
                return self.func_read(self, a)
            # XXX hack: unknown content, keep the read symbolic.
            a.is_term = True
            return a

        # Same-size entry: direct hit.
        if a.size == tmp.size:
            return self.pool[tmp]

        # Wider read: walk consecutive entries and concatenate them.
        if a.size > tmp.size:
            rest = a.size
            ptr = a_val
            out = []
            ptr_index = 0
            # `> 0` rather than truthiness so an overshoot cannot loop forever.
            while rest > 0:
                v = self.find_mem_by_addr(ptr)
                if v is None:
                    # Unknown byte: read it symbolically.
                    v = val = ExprMem(ptr, 8)
                    diff_size = 8
                elif rest >= v.size:
                    val = self.pool[v]
                    diff_size = v.size
                else:
                    diff_size = rest
                    val = self.pool[v][0:diff_size]
                out.append((val, ptr_index, ptr_index + diff_size))
                ptr_index += diff_size
                rest -= diff_size
                ptr = expr_simp(
                    self.eval_expr(
                        ExprOp('+', ptr, ExprInt(uint32(v.size // 8))),
                        eval_cache))
            return expr_simp(ExprCompose(out))

        # Narrower read: low bits of the entry starting here.
        return expr_simp(ExprSlice(self.pool[tmp], 0, a.size))