Пример #1
0
    def _set(self, dst, src):
        """
        Special cases:
        * if dst is an ExprSlice, expand it to affect the full Expression
        * if dst already known, sources are merged
        """
        if dst.size != src.size:
            raise RuntimeError("sanitycheck: args must have same size! %s" %
                               ([(str(arg), arg.size) for arg in [dst, src]]))

        if isinstance(dst, m2_expr.ExprSlice):
            # Complete the source with missing slice parts
            new_dst = dst.arg
            rest = [(m2_expr.ExprSlice(dst.arg, r[0], r[1]), r[0], r[1])
                    for r in dst.slice_rest()]
            all_a = [(src, dst.start, dst.stop)] + rest
            all_a.sort(key=lambda x: x[1])
            args = [expr for (expr, _, _) in all_a]
            new_src = m2_expr.ExprCompose(*args)
        else:
            new_dst, new_src = dst, src

        if new_dst in self._assigns and isinstance(new_src,
                                                   m2_expr.ExprCompose):
            if not isinstance(self[new_dst], m2_expr.ExprCompose):
                # prev_RAX = 0x1122334455667788
                # input_RAX[0:8] = 0x89
                # final_RAX -> ? (assignment are in parallel)
                raise RuntimeError("Concurent access on same bit not allowed")

            # Consider slice grouping
            expr_list = [(new_dst, new_src), (new_dst, self[new_dst])]
            # Find collision
            e_colision = reduce(lambda x, y: x.union(y),
                                (self.get_modified_slice(dst, src)
                                 for (dst, src) in expr_list), set())

            # Sort interval collision
            known_intervals = sorted([(x[1], x[2]) for x in e_colision])

            for i, (_, stop) in enumerate(known_intervals[:-1]):
                if stop > known_intervals[i + 1][0]:
                    raise RuntimeError(
                        "Concurent access on same bit not allowed")

            # Fill with missing data
            missing_i = get_missing_interval(known_intervals, 0, new_dst.size)
            remaining = ((m2_expr.ExprSlice(new_dst, *interval), interval[0],
                          interval[1]) for interval in missing_i)

            # Build the merging expression
            args = list(e_colision.union(remaining))
            args.sort(key=lambda x: x[1])
            starts = [start for (_, start, _) in args]
            assert len(set(starts)) == len(starts)
            args = [expr for (expr, _, _) in args]
            new_src = m2_expr.ExprCompose(*args)

        self._assigns[new_dst] = new_src
Пример #2
0
    def merge_multi_affect(self, affect_list):
        """
        If multiple affections to a same destination are present in
        @affect_list, merge them (in place).

        Only destinations whose affections all have an ExprCompose source
        are merged; other duplicates are left untouched. Bits written by
        none of the merged affections are filled with slices of the
        destination itself.

        For instance, XCHG AH, AL semantic is
        [
            RAX = {RAX[0:8],0,8, RAX[0:8],8,16, RAX[16:64],16,64}
            RAX = {RAX[8:16],0,8, RAX[8:64],8,64}
        ]
        This function will update @affect_list to replace previous ExprAff by
        [
            RAX = {RAX[8:16],0,8, RAX[0:8],8,16, RAX[16:64],16,64}
        ]
        The merged ExprAff is appended at the end of @affect_list.

        @affect_list: list of ExprAff, modified in place
        """
        # Group affections by destination, in first-seen order
        effect = {}
        for aff in affect_list:
            effect.setdefault(aff.dst, []).append(aff)

        for dst, same_dst in effect.items():
            if len(same_dst) <= 1:
                continue

            # Only treat ExprCompose sources
            if not all(isinstance(aff.src, m2_expr.ExprCompose)
                       for aff in same_dst):
                continue

            # Union of the (expr, start, stop) slices each affection modifies
            # NOTE(review): overlapping writes from different affections are
            # not detected here.
            written = set().union(*(aff.get_modified_slice()
                                    for aff in same_dst))
            known_intervals = sorted((part[1], part[2]) for part in written)

            # Untouched bits keep the destination's own value
            missing_i = get_missing_interval(known_intervals, 0, dst.size)
            remaining = ((m2_expr.ExprSlice(dst, *interval),
                          interval[0],
                          interval[1])
                         for interval in missing_i)

            # Build the merging expression
            # NOTE(review): ExprCompose receives (expr, start, stop) triples
            # here; confirm this matches the m2_expr constructor in use.
            slices = sorted(written.union(remaining), key=lambda part: part[1])
            final_src = m2_expr.ExprCompose(slices)

            # Replace the individual affections by the merged one
            for aff in same_dst:
                affect_list.remove(aff)
            affect_list.append(m2_expr.ExprAff(dst, final_src))
Пример #3
0
    def apply_expr_on_state_visit_cache(self, expr, state, cache, level=0):
        """
        Evaluate @expr against @state, depth first.

        Each node's sons are evaluated recursively, the node is rebuilt from
        them and the result goes through self.expr_simp. Results are
        memoised in @cache, keyed by the simplified input expression.
        Integer constants are returned as is, without being cached.

        @expr: expression to evaluate
        @state: mapping resolving identifiers (unknown ones map to
                themselves); memory goes through self.get_mem_state
        @cache: dict of already evaluated (simplified) expressions
        @level: recursion depth, only propagated to recursive calls
        """
        expr = self.expr_simp(expr)

        def visit(sub):
            return self.apply_expr_on_state_visit_cache(sub, state, cache,
                                                        level + 1)

        if expr in cache:
            result = cache[expr]
        elif expr.is_int():
            return expr
        elif expr.is_id():
            label = expr.name
            if isinstance(label, asmblock.AsmLabel) and label.offset is not None:
                # Labels with a resolved offset become constants
                result = m2_expr.ExprInt(label.offset, expr.size)
            else:
                result = state.get(expr, expr)
        elif expr.is_mem():
            address = visit(expr.arg)
            result = self.get_mem_state(m2_expr.ExprMem(address, expr.size))
            assert expr.size == result.size
        elif expr.is_cond():
            result = m2_expr.ExprCond(visit(expr.cond),
                                      visit(expr.src1),
                                      visit(expr.src2))
        elif expr.is_slice():
            result = m2_expr.ExprSlice(visit(expr.arg), expr.start, expr.stop)
        elif expr.is_op():
            operands = []
            for operand in expr.args:
                evaluated = visit(operand)
                assert operand.size == evaluated.size
                operands.append(evaluated)
            result = m2_expr.ExprOp(expr.op, *operands)
        elif expr.is_compose():
            result = m2_expr.ExprCompose(*[visit(part) for part in expr.args])
        else:
            raise TypeError("Unknown expr type")

        result = self.expr_simp(result)
        assert expr.size == result.size
        cache[expr] = result
        return result
Пример #4
0
    def eval_ExprMem(self, e, eval_cache=None):
        """
        Evaluate the memory access @e in the current symbolic state.

        The address is evaluated first, then:
        * an exact entry for the access in self.symbols is returned as is;
        * if no memory cell is known at the evaluated address, the value is
          rebuilt from the known cells overlapping the access, holes being
          filled with fresh ExprMem reads; otherwise self.func_read is used
          for concrete (ExprInt) addresses, else the access itself is
          returned, flagged as terminal;
        * if the known cell is smaller than the access, consecutive cells
          are concatenated (unknown bytes are read as 8-bit ExprMem);
        * otherwise the low @e.size bits of the known cell's value are
          returned.

        @e: ExprMem to evaluate
        @eval_cache: optional dict passed to self.eval_expr and
                     self.get_mem_overlapping
        """
        if eval_cache is None:
            eval_cache = {}
        a_val = self.expr_simp(self.eval_expr(e.arg, eval_cache))
        if a_val != e.arg:
            a = self.expr_simp(m2_expr.ExprMem(a_val, size=e.size))
        else:
            a = e
        if a in self.symbols:
            return self.symbols[a]
        tmp = None
        # test if mem lookup is known
        if a_val in self.symbols.symbols_mem:
            tmp = self.symbols.symbols_mem[a_val][0]
        if tmp is None:

            v = self.find_mem_by_addr(a_val)
            if not v:
                out = []
                ov = self.get_mem_overlapping(a, eval_cache)
                off_base = 0
                ov.sort()
                # off is the byte offset of cell x relative to the access
                # (negative when x starts before it).
                # NOTE(review): pieces are packed contiguously from bit 0
                # (off_base), so a gap before an overlapping cell would
                # misplace it; confirm get_mem_overlapping never yields gaps.
                for off, x in ov:
                    if off >= 0:
                        m = min(a.size - off * 8, x.size)
                        ee = m2_expr.ExprSlice(self.symbols[x], 0, m)
                        ee = self.expr_simp(ee)
                        out.append((ee, off_base, off_base + m))
                        off_base += m
                    else:
                        m = min(a.size - off * 8, x.size)
                        ee = m2_expr.ExprSlice(self.symbols[x], -off * 8, m)
                        ff = self.expr_simp(ee)
                        new_off_base = off_base + m + off * 8
                        out.append((ff, off_base, new_off_base))
                        off_base = new_off_base
                if out:
                    # Fill the holes with raw memory reads
                    missing_slice = self.rest_slice(out, 0, a.size)
                    for sa, sb in missing_slice:
                        # Bit offset -> byte offset; floor division keeps it
                        # an int on Python 3 as well
                        ptr = self.expr_simp(
                            a_val + m2_expr.ExprInt_from(a_val, sa // 8))
                        mm = m2_expr.ExprMem(ptr, size=sb - sa)
                        mm.is_term = True
                        mm.is_simp = True
                        out.append((mm, sa, sb))
                    out.sort(key=lambda x: x[1])
                    ee = m2_expr.ExprSlice(m2_expr.ExprCompose(out), 0, a.size)
                    ee = self.expr_simp(ee)
                    return ee
            if self.func_read and isinstance(a.arg, m2_expr.ExprInt):
                return self.func_read(a)
            else:
                # XXX hack test
                a.is_term = True
                return a
        # bigger lookup: concatenate consecutive cells
        if a.size > tmp.size:
            rest = a.size
            ptr = a_val
            out = []
            ptr_index = 0
            while rest:
                v = self.find_mem_by_addr(ptr)
                if v is None:
                    # Unknown byte: read it from memory
                    val = m2_expr.ExprMem(ptr, 8)
                    v = val
                    diff_size = 8
                elif rest >= v.size:
                    val = self.symbols[v]
                    diff_size = v.size
                else:
                    diff_size = rest
                    val = self.symbols[v][0:diff_size]
                val = (val, ptr_index, ptr_index + diff_size)
                out.append(val)
                ptr_index += diff_size
                rest -= diff_size
                ptr = self.expr_simp(
                    self.eval_expr(
                        m2_expr.ExprOp('+', ptr,
                                       m2_expr.ExprInt_from(ptr, v.size // 8)),
                        eval_cache))
            e = self.expr_simp(m2_expr.ExprCompose(out))
            return e
        # part lookup
        tmp = self.expr_simp(m2_expr.ExprSlice(self.symbols[tmp], 0, a.size))
        return tmp
Пример #5
0
    def get_mem_state(self, expr):
        """
        Evaluate the memory access @expr in the current state (self.symbols).

        * if no memory cell is known at the address, the value is rebuilt
          from the known cells overlapping @expr, holes being filled with
          fresh ExprMem reads; with no overlap at all, self.func_read is
          used for concrete (ExprInt) addresses, else @expr is returned;
        * if the known cell is smaller than @expr, consecutive cells are
          concatenated (unknown bytes are read as 8-bit ExprMem);
        * otherwise the low @expr.size bits of the cell's value are
          returned.

        @expr: the memory key (ExprMem) to evaluate
        """
        ptr, size = expr.arg, expr.size
        ret = self.find_mem_by_addr(ptr)
        if not ret:
            out = []
            overlaps = self.get_mem_overlapping(expr)
            off_base = 0
            # off is the byte offset of cell mem relative to @expr
            # (negative when mem starts before it).
            # NOTE(review): pieces are packed contiguously from bit 0
            # (off_base), so a gap before an overlapping cell would misplace
            # it; confirm get_mem_overlapping never yields gaps.
            for off, mem in overlaps:
                if off >= 0:
                    new_size = min(size - off * 8, mem.size)
                    tmp = self.expr_simp(self.symbols[mem][0:new_size])
                    out.append((tmp, off_base, off_base + new_size))
                    off_base += new_size
                else:
                    new_size = min(size - off * 8, mem.size)
                    tmp = self.expr_simp(self.symbols[mem][-off * 8:new_size])
                    new_off_base = off_base + new_size + off * 8
                    out.append((tmp, off_base, new_off_base))
                    off_base = new_off_base
            if out:
                missing_slice = self.rest_slice(out, 0, size)
                for slice_start, slice_stop in missing_slice:
                    # Address every hole from the base pointer; overwriting
                    # ptr here would shift all following holes.
                    # Floor division keeps the byte offset an int on Python 3.
                    sub_ptr = self.expr_simp(
                        ptr + m2_expr.ExprInt(slice_start // 8, ptr.size))
                    mem = m2_expr.ExprMem(sub_ptr, slice_stop - slice_start)
                    out.append((mem, slice_start, slice_stop))
                out.sort(key=lambda x: x[1])
                # Distinct name: do not rebind the @expr parameter
                parts = [part for (part, _, _) in out]
                tmp = m2_expr.ExprSlice(m2_expr.ExprCompose(*parts), 0, size)
                tmp = self.expr_simp(tmp)
                return tmp

            if self.func_read and isinstance(ptr, m2_expr.ExprInt):
                return self.func_read(expr)
            else:
                return expr
        # bigger lookup: concatenate consecutive cells
        if size > ret.size:
            rest = size
            cur_ptr = ptr
            out = []
            ptr_index = 0
            while rest:
                mem = self.find_mem_by_addr(cur_ptr)
                if mem is None:
                    # Unknown byte: read it from memory
                    value = m2_expr.ExprMem(cur_ptr, 8)
                    mem = value
                    diff_size = 8
                elif rest >= mem.size:
                    value = self.symbols[mem]
                    diff_size = mem.size
                else:
                    diff_size = rest
                    value = self.symbols[mem][0:diff_size]
                out.append((value, ptr_index, ptr_index + diff_size))
                ptr_index += diff_size
                rest -= diff_size
                cur_ptr = self.expr_simp(
                    cur_ptr + m2_expr.ExprInt(mem.size // 8, cur_ptr.size))
            # out is already ordered: ptr_index only grows
            parts = [part for (part, _, _) in out]
            ret = self.expr_simp(m2_expr.ExprCompose(*parts))
            return ret
        # part lookup
        ret = self.expr_simp(self.symbols[ret][:size])
        return ret