Exemple #1
0
    def test_ClassDef(self):
        """Check symbexec memory lookups and partial reads (legacy API).

        Seeds a symbexec engine over ir_x86_32 with memory cells and
        identifiers, then checks find_mem_by_addr, eval_ExprMem on exact,
        overlapping and unknown addresses, an identity func_read hook,
        modified() and symbol-table lookups.
        """
        from miasm2.expression.expression import ExprInt32, ExprId, ExprMem, ExprCompose
        from miasm2.arch.x86.sem import ir_x86_32
        from miasm2.ir.symbexec import symbexec

        # Constant addresses; addrX is expected to match no stored cell
        addrX = ExprInt32(-1)
        addr0 = ExprInt32(0)
        addr1 = ExprInt32(1)
        addr8 = ExprInt32(8)
        addr9 = ExprInt32(9)
        addr20 = ExprInt32(20)
        addr40 = ExprInt32(40)
        addr50 = ExprInt32(50)
        # Memory cells; the "v" / "w" variants use explicit 8 / 16-bit sizes
        mem0 = ExprMem(addr0)
        mem1 = ExprMem(addr1)
        mem8 = ExprMem(addr8)
        mem9 = ExprMem(addr9)
        mem20 = ExprMem(addr20)
        mem40v = ExprMem(addr40, 8)
        mem40w = ExprMem(addr40, 16)
        mem50v = ExprMem(addr50, 8)
        mem50w = ExprMem(addr50, 16)
        # Symbolic values used in the initial state
        id_x = ExprId('x')
        id_y = ExprId('y', 8)
        id_a = ExprId('a')
        id_eax = ExprId('eax_init')

        # Initial state: memory cells and identifiers mapped to values
        e = symbexec(
            ir_x86_32(), {
                mem0: id_x,
                mem1: id_y,
                mem9: id_x,
                mem40w: id_x,
                mem50v: id_y,
                id_a: addr0,
                id_eax: addr0
            })
        # Address lookup: a stored address, then one with no stored cell
        self.assertEqual(e.find_mem_by_addr(addr0), mem0)
        self.assertEqual(e.find_mem_by_addr(addrX), None)
        # The address expression 1 - 1 is expected to resolve to mem0
        self.assertEqual(e.eval_ExprMem(ExprMem(addr1 - addr1)), id_x)
        self.assertEqual(e.eval_ExprMem(ExprMem(addr1, 8)), id_y)
        # Reads overlapping stored cells: the expected results mix slices
        # of stored values with still-unresolved ExprMem parts
        self.assertEqual(
            e.eval_ExprMem(ExprMem(addr1 + addr1)),
            ExprCompose([(id_x[16:32], 0, 16),
                         (ExprMem(ExprInt32(4), 16), 16, 32)]))
        self.assertEqual(
            e.eval_ExprMem(mem8),
            ExprCompose([(id_x[0:24], 0, 24),
                         (ExprMem(ExprInt32(11), 8), 24, 32)]))
        # 8-bit read at 40, where a 16-bit cell holding x was stored
        self.assertEqual(e.eval_ExprMem(mem40v), id_x[:8])
        # 16-bit read at 50: the stored 8-bit y, then an unknown byte at 51
        self.assertEqual(
            e.eval_ExprMem(mem50w),
            ExprCompose([(id_y, 0, 8), (ExprMem(ExprInt32(51), 8), 8, 16)]))
        # Nothing stored at 20: the read is returned unchanged
        self.assertEqual(e.eval_ExprMem(mem20), mem20)
        # Same result once an identity func_read hook is installed
        e.func_read = lambda x: x
        self.assertEqual(e.eval_ExprMem(mem20), mem20)
        # modified() covers exactly the entries of the symbol table
        self.assertEqual(set(e.modified()), set(e.symbols))
        # Looking up an absent memory key raises KeyError
        self.assertRaises(KeyError, e.symbols.__getitem__,
                          ExprMem(ExprInt32(100)))
Exemple #2
0
def simp_subwc_cf(_, expr):
    """SUBWC_CF(A, B, SUB_CF(C, D)) => SUB_CF({C, A}, {D, B})

    A subtract-with-carry whose incoming carry is produced by a plain
    subtraction C - D is the upper half of the double-width subtraction
    {C, A} - {D, B}. ExprCompose places its first argument in the low
    bits, so the carry-producing pair (C, D) must come first. Composing
    {A, C} would put the upper half in the low bits and yield the carry
    of a different subtraction.

    Any expression not matching this shape is returned unchanged.
    """
    if not expr.is_op('FLAG_SUBWC_CF'):
        return expr
    op3 = expr.args[2]
    if not op3.is_op("FLAG_SUB_CF"):
        return expr

    # Low halves (carry producer) first, high halves (carry consumer) last
    op1 = ExprCompose(op3.args[0], expr.args[0])
    op2 = ExprCompose(op3.args[1], expr.args[1])

    return ExprOp("FLAG_SUB_CF", op1, op2)
Exemple #3
0
def simp_sign_subwc_cf(expr_s, expr):
    """SIGN_SUBWC(A, B, SUB_CF(C, D)) => SIGN_SUB({C, A}, {D, B})

    The sign of a subtract-with-carry fed by the carry of C - D equals the
    sign of the double-width subtraction whose low half is C - D and whose
    high half is A - B. ExprCompose places its first argument in the low
    bits, so (C, D) must come first; the sign bit then comes from A - B.

    Any expression not matching this shape is returned unchanged.
    """
    if not expr.is_op('FLAG_SIGN_SUBWC'):
        return expr
    op3 = expr.args[2]
    if not op3.is_op("FLAG_SUB_CF"):
        return expr

    # Low halves (carry producer) first, high halves (sign source) last
    op1 = ExprCompose(op3.args[0], expr.args[0])
    op2 = ExprCompose(op3.args[1], expr.args[1])

    return ExprOp("FLAG_SIGN_SUB", op1, op2)
def simp_compose(e_s, expr):
    """Common simplifications on an ExprCompose.

    Rewrites tried in order on the merged, flattened parts of @expr:
    - a single part as wide as @expr is returned directly;
    - {X[z:], 0} filling X.size bits becomes X >> z;
    - two adjacent, contiguous memory reads become one wider read;
    - if exactly one distinct ExprCond condition appears among the
      original parts, it is hoisted out of the composition.
    @e_s: simplifier callable applied to sub-expressions
    @expr: ExprCompose to simplify
    """
    # Merge neighbouring slices, then inline the parts of nested composes
    flat = []
    for part in merge_sliceto_slice(expr):
        if part.is_compose():
            flat.extend(part.args)
        else:
            flat.append(part)

    # A lone part covering the whole width is the result itself
    if len(flat) == 1 and flat[0].size == expr.size:
        return flat[0]

    # {X[z:X.size], 0} where the zero fills the vacated bits => X >> z
    if len(flat) == 2 and flat[1].is_int(0):
        low = flat[0]
        if (low.is_slice() and
                low.stop == low.arg.size and
                low.size + flat[1].size == low.arg.size):
            return low.arg >> ExprInt(low.start, low.arg.size)

    # Two adjacent reads whose addresses differ by the first read's byte
    # size form a single read covering both
    for idx in range(len(flat) - 1):
        cur, nxt = flat[idx], flat[idx + 1]
        if not (cur.is_mem() and nxt.is_mem()):
            continue
        gap = e_s(nxt.arg - cur.arg)
        if gap.is_int() and cur.size % 8 == 0 and int(gap) == cur.size / 8:
            merged = ExprMem(cur.arg, cur.size + nxt.size)
            return ExprCompose(*(flat[:idx] + [merged] + flat[idx + 2:]))

    # {a, x?b:d, x?c:e, f} => x?{a, b, c, f}:{a, d, e, f}
    conds = set(part.cond for part in expr.args if part.is_cond())
    if len(conds) != 1:
        return ExprCompose(*flat)
    cond = next(iter(conds))
    when_true = [part.src1 if part.is_cond() else part for part in expr.args]
    when_false = [part.src2 if part.is_cond() else part for part in expr.args]
    return ExprCond(cond,
                    e_s(ExprCompose(*when_true)),
                    e_s(ExprCompose(*when_false)))
Exemple #5
0
    def _func_read(self, expr_mem):
        """Read @expr_mem, symbolizing bytes listed in dse_memory_range.

        Reads at a non-constant address are returned unchanged. With an
        empty dse_memory_range the parent reader handles the whole access.
        Otherwise each byte is either symbolized via dse_memory_to_expr or
        read concretely through the parent reader, and the bytes are
        recomposed and simplified.
        """
        if not expr_mem.ptr.is_int():
            return expr_mem
        start = int(expr_mem.ptr)

        if not self.dse_memory_range:
            # Nothing is symbolized: one concrete read covers it all
            return super(ESETrackModif, self)._func_read(expr_mem)

        # One entry per byte, from the lowest address upwards
        ptr_size = expr_mem.ptr.size
        parts = []
        for byte_addr in xrange(start, start + (expr_mem.size / 8)):
            if byte_addr not in self.dse_memory_range:
                byte_mem = ExprMem(ExprInt(byte_addr, ptr_size), 8)
                parts.append(super(ESETrackModif, self)._func_read(byte_mem))
            else:
                parts.append(self.dse_memory_to_expr(byte_addr))

        if len(parts) == 1:
            return parts[0]
        # Simplification merges neighbouring constant bytes
        return self.expr_simp(ExprCompose(*parts))
def simp_ext(_, expr):
    """Expand zeroExt_*/signExt_* operators into ExprCompose form.

    zeroExt(A) => {A, 0}
    signExt(A) => {A, msb(A) ? all-ones : 0}

    Extending to the argument's own size is the identity, and the argument
    is returned unchanged in both cases. Any other expression is returned
    as-is.
    """
    if expr.op.startswith('zeroExt_'):
        arg = expr.args[0]
        if expr.size == arg.size:
            return arg
        return ExprCompose(arg, ExprInt(0, expr.size - arg.size))

    if expr.op.startswith("signExt_"):
        arg = expr.args[0]
        add_size = expr.size - arg.size
        if add_size == 0:
            # Same-size extension: mirror the zeroExt_ branch instead of
            # building a 0-bit ExprInt/ExprCond for the (empty) high part
            return arg
        # High part replicates the sign bit of arg
        new_expr = ExprCompose(
            arg,
            ExprCond(arg.msb(), ExprInt(size2mask(add_size), add_size),
                     ExprInt(0, add_size)))
        return new_expr
    return expr
Exemple #7
0
 def eval_exprcompose(self, expr, **kwargs):
     """[DEV]: Evaluate each part of the ExprCompose @expr against the
     current state and rebuild an ExprCompose from the results."""
     evaluated = [self.eval_expr_visitor(part, **kwargs) for part in expr.args]
     return ExprCompose(*evaluated)
Exemple #8
0
def simp_cond(_, expr):
    """
    Common simplifications on ExprCond.

    Eval exprcond src1/src2 with satisfiable/unsatisfiable condition
    propagation: in src1 a 1-bit condition is known to be 1, in src2 it is
    known to be 0. Then apply local rewrites (negated condition, equal
    branches, constant condition, nested conditions on the same test,
    a|int with int != 0, (C?int1:int2) conditions, compose conditions
    containing zero parts).
    """
    # Propagate the known condition value into each branch (1-bit only)
    if (not expr.cond.is_int()) and expr.cond.size == 1:
        src1 = expr.src1.replace_expr({expr.cond: ExprInt(1, 1)})
        src2 = expr.src2.replace_expr({expr.cond: ExprInt(0, 1)})
        if src1 != expr.src1 or src2 != expr.src2:
            return ExprCond(expr.cond, src1, src2)

    # -A ? B:C => A ? B:C
    if expr.cond.is_op('-') and len(expr.cond.args) == 1:
        expr = ExprCond(expr.cond.args[0], expr.src1, expr.src2)
    # a?x:x
    elif expr.src1 == expr.src2:
        expr = expr.src1
    # int ? A:B => A or B
    elif expr.cond.is_int():
        # int() is independent of how ExprInt stores its value internally
        if int(expr.cond) == 0:
            expr = expr.src2
        else:
            expr = expr.src1
    # a?(a?b:c):x => a?b:x
    elif expr.src1.is_cond() and expr.cond == expr.src1.cond:
        expr = ExprCond(expr.cond, expr.src1.src1, expr.src2)
    # a?x:(a?b:c) => a?x:c
    elif expr.src2.is_cond() and expr.cond == expr.src2.cond:
        expr = ExprCond(expr.cond, expr.src1, expr.src2.src2)
    # a|int ? b:c => b with int != 0
    elif (expr.cond.is_op('|') and
          expr.cond.args[1].is_int() and
          int(expr.cond.args[1]) != 0):
        return expr.src1

    # (C?int1:int2)?(A:B) => rewritten as a condition on C
    elif (expr.cond.is_cond() and
          expr.cond.src1.is_int() and
          expr.cond.src2.is_int()):
        int1 = int(expr.cond.src1)
        int2 = int(expr.cond.src2)
        if int1 and int2:
            expr = expr.src1
        elif int1 == 0 and int2 == 0:
            expr = expr.src2
        elif int1 == 0 and int2:
            expr = ExprCond(expr.cond.cond, expr.src2, expr.src1)
        elif int1 and int2 == 0:
            expr = ExprCond(expr.cond.cond, expr.src1, expr.src2)

    elif expr.cond.is_compose():
        # {0, X, 0}?(A:B) => X?(A:B)
        args = [arg for arg in expr.cond.args if not arg.is_int(0)]
        if not args:
            # Every part is the constant 0: the condition is always false,
            # and an empty ExprCompose must not be built
            return expr.src2
        if len(args) == 1:
            arg = args.pop()
            return ExprCond(arg, expr.src1, expr.src2)
        elif len(args) < len(expr.cond.args):
            return ExprCond(ExprCompose(*args), expr.src1, expr.src2)
    return expr
Exemple #9
0
def rev16(ir, instr, arg1, arg2):
    """Assign to @arg1 the value of @arg2 with the two bytes of every
    16-bit halfword swapped.

    Returns ([assignment], []) with no extra IR blocks.
    """
    swapped = []
    for pos in xrange(arg2.size / 8):
        # pos ^ 1 is the other byte of the same halfword
        src = pos ^ 1
        swapped.append(arg2[src * 8:(src + 1) * 8])
    return [ExprAssign(arg1, ExprCompose(*swapped))], []
Exemple #10
0
 def check(self):
     """Assert the DSE evaluates EDX to EDX << (zero-extended CL & 0x1F)."""
     regs = self.dse.ir_arch.arch.regs
     value = self.dse.eval_expr(regs.EDX)
     # Shift count: CL zero-extended to 32 bits, then masked to 0..31
     count = ExprCompose(regs.ECX[0:8], ExprInt(0x0, 24)) & ExprInt(0x1F, 32)
     # Seeing '<<' in the result shows execution went through the
     # generated label performing the shift
     assert value == ExprOp('<<', regs.EDX, count)
Exemple #11
0
def rev(ir, instr, arg1, arg2):
    """Assign to @arg1 the value of @arg2 with its byte order reversed.

    Returns ([assignment], []) with no extra IR blocks.
    """
    # Bytes listed from the lowest bits upwards, then reversed
    byte_list = [arg2[bit:bit + 8] for bit in xrange(0, arg2.size, 8)]
    return [ExprAssign(arg1, ExprCompose(*byte_list[::-1]))], []
Exemple #12
0
 def eval_exprcompose(self, expr, **kwargs):
     """[DEV]: Evaluate the ExprCompose @expr against the current state.

     As soon as one part evaluates to the TOP identifier, the whole
     composition is TOP (exprid_top(expr)) and later parts are skipped.
     """
     parts = []
     for part in expr.args:
         parts.append(self.eval_expr_visitor(part, **kwargs))
         if parts[-1].is_id(TOPSTR):
             return exprid_top(expr)
     return ExprCompose(*parts)
    def categorize(self, node, lvl=0, **kwargs):
        """Recursively apply rules to @node

        Children are categorized first. A fresh ExprNode* is then rebuilt
        around the (possibly updated) child expressions, and the result of
        self.apply_rules on it is stored in its .info attribute.

        @node: ExprNode to analyze
        @lvl: actual recursion level (log indentation; forwarded to
              apply_rules)
        @kwargs: forwarded to the recursive calls and to apply_rules
        Returns the rebuilt node.
        Raises TypeError for an unsupported expression type.
        """

        expr = node.expr
        log_reduce.debug("\t" * lvl + "Reduce...: %s", node.expr)
        if isinstance(expr, ExprId):
            node = ExprNodeId(expr)
        elif isinstance(expr, ExprInt):
            node = ExprNodeInt(expr)
        elif isinstance(expr, ExprLoc):
            node = ExprNodeLoc(expr)
        elif isinstance(expr, ExprMem):
            arg = self.categorize(node.arg, lvl=lvl + 1, **kwargs)
            node = ExprNodeMem(ExprMem(arg.expr, expr.size))
            node.arg = arg
        elif isinstance(expr, ExprSlice):
            arg = self.categorize(node.arg, lvl=lvl + 1, **kwargs)
            node = ExprNodeSlice(ExprSlice(arg.expr, expr.start, expr.stop))
            node.arg = arg
        elif isinstance(expr, ExprOp):
            new_args = []
            for arg in node.args:
                new_a = self.categorize(arg, lvl=lvl + 1, **kwargs)
                # Sanity check: categorizing must preserve operand sizes
                assert new_a.expr.size == arg.expr.size
                new_args.append(new_a)
            node = ExprNodeOp(ExprOp(expr.op, *[x.expr for x in new_args]))
            node.args = new_args
        elif isinstance(expr, ExprCompose):
            new_args = []
            new_expr_args = []
            for arg in node.args:
                arg = self.categorize(arg, lvl=lvl + 1, **kwargs)
                new_args.append(arg)
                new_expr_args.append(arg.expr)
            new_expr = ExprCompose(*new_expr_args)
            node = ExprNodeCompose(new_expr)
            node.args = new_args
        elif isinstance(expr, ExprCond):
            cond = self.categorize(node.cond, lvl=lvl + 1, **kwargs)
            src1 = self.categorize(node.src1, lvl=lvl + 1, **kwargs)
            src2 = self.categorize(node.src2, lvl=lvl + 1, **kwargs)
            node = ExprNodeCond(ExprCond(cond.expr, src1.expr, src2.expr))
            node.cond, node.src1, node.src2 = cond, src1, src2
        else:
            # Format the message; passing type(expr) as a second argument
            # left the %r placeholder unfilled
            raise TypeError("Unknown Expr Type %r" % (type(expr),))

        node.info = self.apply_rules(node, lvl=lvl, **kwargs)
        log_reduce.debug("\t" * lvl + "Reduce result: %s %r", node.expr,
                         node.info)
        return node
Exemple #14
0
def casp(ir, instr, arg1, arg2, arg3):
    """Build IR for casp (paired compare-and-swap).

    Compares the memory at the pointer of the 'preinc' operand @arg3 with
    the register pair starting at @arg1, may store the register pair
    starting at @arg2 there, then reloads the @arg1 pair from memory.
    Returns (assignments for the current block, [store block, register
    update block]).
    """
    # XXX TODO: memory barrier
    e = []
    # Register bank chosen from the operand width
    if arg1.size == 32:
        regs = gpregs32_expr
    else:
        regs = gpregs64_expr
    index1 = regs.index(arg1)
    index2 = regs.index(arg2)

    # TODO endianness
    # Each register pair (Rn, Rn+1) is composed into one double-width value
    comp_value = ExprCompose(regs[index1], regs[index1 + 1])
    new_value = ExprCompose(regs[index2], regs[index2 + 1])
    assert arg3.is_op('preinc')
    ptr = arg3.args[0]
    data = ExprMem(ptr, comp_value.size)

    # New locations for the store block and the register update block;
    # loc_next is the location following this instruction
    loc_store = ExprLoc(ir.loc_db.add_location(), ir.IRDst.size)
    loc_do = ExprLoc(ir.loc_db.add_location(), ir.IRDst.size)
    loc_next = ExprLoc(ir.get_next_loc_key(instr), ir.IRDst.size)

    # NOTE(review): if FLAG_EQ_CMP evaluates to 1 when data == comp_value,
    # the "equal" case jumps to loc_do and skips the store, which looks
    # inverted for a compare-and-swap. Confirm against FLAG_EQ_CMP's
    # definition.
    e.append(
        ExprAssign(
            ir.IRDst,
            ExprCond(ExprOp("FLAG_EQ_CMP", data, comp_value), loc_do,
                     loc_store)))

    # Store block: write the new pair, then continue to loc_do
    e_store = []
    e_store.append(ExprAssign(data, new_value))
    e_store.append(ExprAssign(ir.IRDst, loc_do))
    blk_store = IRBlock(loc_store.loc_key, [AssignBlock(e_store, instr)])

    # Register update block: split the memory value back into the pair.
    # NOTE(review): when reached from loc_store, `data` is read after the
    # store, so the pair receives new_value rather than the value that was
    # in memory before. Confirm this is intended.
    e_do = []
    e_do.append(ExprAssign(regs[index1], data[:data.size / 2]))
    e_do.append(ExprAssign(regs[index1 + 1], data[data.size / 2:]))
    e_do.append(ExprAssign(ir.IRDst, loc_next))
    blk_do = IRBlock(loc_do.loc_key, [AssignBlock(e_do, instr)])

    return e, [blk_store, blk_do]
Exemple #15
0
def mrs(ir, insr, arg1, arg2, arg3, arg4, arg5):
    """Build IR for mrs reading the flags register.

    Only the system register encoded as (3, c4, c2, 0) is supported. Its
    value packs nf/zf/cf/of into bits 31..28 (NZCV layout) above 28 zero
    bits, and is zero-extended to @arg1's size.
    Raises NotImplementedError for any other system register.
    """
    reads_flags = (arg2.is_int(3) and arg3.is_id("c4") and
                   arg4.is_id("c2") and arg5.is_int(0))
    if not reads_flags:
        raise NotImplementedError("MRS not implemented")
    # ExprCompose lists its parts from the low bits to the high bits
    flags = ExprCompose(ExprInt(0x0, 28), of, cf, zf, nf)
    return [ExprAssign(arg1, flags.zeroExtend(arg1.size))], []
Exemple #16
0
 def read(self, ptr, size):
     """
     Return the value stored at memory address @ptr.

     @ptr: Expr representing the memory address
     @size: memory size (in bits), byte aligned

     When a memory array is registered for the base of @ptr, the value is
     the ExprCompose of the parts that array returns; otherwise the plain
     read ExprMem(@ptr, @size) is returned.
     """
     assert size % 8 == 0
     base, offset = get_expr_base_offset(ptr)
     memarray = self.base_to_memarray.get(base, None)
     if memarray is None:
         # No memory tracked for this base: the read stays symbolic
         return ExprMem(ptr, size)
     return ExprCompose(*memarray.read(offset, size))
Exemple #17
0
    def mem_read(self, expr):
        """
        [DEV]: Override to modify the effective memory reads

        Return the symbolic value of the ExprMem @expr. Parts reported as
        unknown that are memory reads go through self.func_read when it is
        set. The parts are recomposed and simplified, and the result must
        keep the size of @expr.
        @expr: ExprMem
        """
        pieces = []
        for known, part in self._resolve_mem_parts(expr):
            use_hook = (not known and part.is_mem() and
                        self.func_read is not None)
            pieces.append(self.func_read(part) if use_hook else part)

        ret = self.expr_simp(ExprCompose(*pieces))
        assert ret.size == expr.size
        return ret
Exemple #18
0
    def test_ClassDef(self):
        """Check SymbolicExecutionEngine memory reads and state updates.

        Seeds the engine over ir_x86_32 with memory cells and identifiers,
        then checks find_mem_by_addr, eval_expr on exact, overlapping and
        unknown reads, the func_read hook, modified()/symbols, and the
        apply_expr / eval_ir / apply_change update paths.
        """
        from miasm2.expression.expression import ExprInt, ExprId, ExprMem, \
            ExprCompose, ExprAff
        from miasm2.arch.x86.sem import ir_x86_32
        from miasm2.ir.symbexec import SymbolicExecutionEngine
        from miasm2.ir.ir import AssignBlock

        # 32-bit constant addresses; addrX is expected to match no cell
        addrX = ExprInt(-1, 32)
        addr0 = ExprInt(0, 32)
        addr1 = ExprInt(1, 32)
        addr8 = ExprInt(8, 32)
        addr9 = ExprInt(9, 32)
        addr20 = ExprInt(20, 32)
        addr40 = ExprInt(40, 32)
        addr50 = ExprInt(50, 32)
        # Memory cells; the "v" / "w" variants use explicit 8 / 16-bit sizes
        mem0 = ExprMem(addr0)
        mem1 = ExprMem(addr1, 8)
        mem8 = ExprMem(addr8)
        mem9 = ExprMem(addr9)
        mem20 = ExprMem(addr20)
        mem40v = ExprMem(addr40, 8)
        mem40w = ExprMem(addr40, 16)
        mem50v = ExprMem(addr50, 8)
        mem50w = ExprMem(addr50, 16)
        # Symbolic values used in the initial state
        id_x = ExprId('x')
        id_y = ExprId('y', 8)
        id_a = ExprId('a')
        id_eax = ExprId('eax_init')

        # Initial state: memory cells and identifiers mapped to values
        e = SymbolicExecutionEngine(
            ir_x86_32(), {
                mem0: id_x,
                mem1: id_y,
                mem9: id_x,
                mem40w: id_x[:16],
                mem50v: id_y,
                id_a: addr0,
                id_eax: addr0
            })
        # Address lookup: a stored address, then one with no stored cell
        self.assertEqual(e.find_mem_by_addr(addr0), mem0)
        self.assertEqual(e.find_mem_by_addr(addrX), None)
        # The address expression 1 - 1 is expected to resolve to mem0
        self.assertEqual(e.eval_expr(ExprMem(addr1 - addr1)), id_x)
        self.assertEqual(e.eval_expr(ExprMem(addr1, 8)), id_y)
        # Reads overlapping stored cells: the expected results mix slices
        # of stored values with still-unresolved ExprMem parts
        self.assertEqual(e.eval_expr(ExprMem(addr1 + addr1)),
                         ExprCompose(id_x[16:32], ExprMem(ExprInt(4, 32), 16)))
        self.assertEqual(e.eval_expr(mem8),
                         ExprCompose(id_x[0:24], ExprMem(ExprInt(11, 32), 8)))
        # 8-bit read at 40, where the 16-bit cell holds x[:16]
        self.assertEqual(e.eval_expr(mem40v), id_x[:8])
        # 16-bit read at 50: the stored 8-bit y, then an unknown byte at 51
        self.assertEqual(e.eval_expr(mem50w),
                         ExprCompose(id_y, ExprMem(ExprInt(51, 32), 8)))
        # Nothing stored at 20: the read is returned unchanged
        self.assertEqual(e.eval_expr(mem20), mem20)
        # Same result once an identity func_read hook is installed
        e.func_read = lambda x: x
        self.assertEqual(e.eval_expr(mem20), mem20)
        # modified() covers exactly the entries of the symbol table
        self.assertEqual(set(e.modified()), set(e.symbols))
        # Looking up an absent memory key raises KeyError
        self.assertRaises(KeyError, e.symbols.__getitem__,
                          ExprMem(ExprInt(100, 32)))
        # apply_expr on an ExprAff updates the state: eax_init then reads
        # back as addr9
        self.assertEqual(e.apply_expr(id_eax), addr0)
        self.assertEqual(e.apply_expr(ExprAff(id_eax, addr9)), addr9)
        self.assertEqual(e.apply_expr(id_eax), addr9)

        # apply_change / eval_ir / apply_expr

        ## x = a (with a = 0x0)
        assignblk = AssignBlock({id_x: id_a})
        e.eval_ir(assignblk)
        self.assertEqual(e.apply_expr(id_x), addr0)

        ## x = a (without replacing 'a' with 0x0)
        e.apply_change(id_x, id_a)
        self.assertEqual(e.apply_expr(id_x), id_a)

        ## x = a (with a = 0x0)
        self.assertEqual(e.apply_expr(assignblk.dst2ExprAff(id_x)), addr0)
        self.assertEqual(e.apply_expr(id_x), addr0)
Exemple #19
0
"""

This example demonstrates the recovery of possible C types for an arbitrary
variable in an assembly code (the types are inferred from the function
argument types). It also displays the C code used to access this variable.

Input:
* definitions of the C types that can be used by the code
* layout of structures (packed/not packed)
* prototype of the analyzed function

Algorithm:
The DepGraph of the target variable is computed, which gives possible
expressions for this variable. For each DepGraph solution, if the expression
depends on typed arguments, the code infers the variable type and displays the C
code to access this variable.


Here be dragons:
For the moment, Miasm can infer C types (and generate C) only for simple
expressions: it supports accesses that involve neither arithmetic nor
conditional expressions, such as the following:
* var1.field
* var1[12][4]
* *(var1.field->tab[4])

Unsupported forms:
* var1 + var2
* var1[var2+4]
* var1?var2->field:6
Exemple #20
0
# Bind every expression alternative to the `expr` grammar element.
# NOTE(review): presumably `expr` is a pyparsing Forward declared earlier in
# the file, which lets the alternatives refer to `expr` recursively.
expr << (expr_int | expr_id | expr_loc | expr_slice | expr_mem | expr_cond | \
         expr_compose | expr_op | expr_aff)

def parse_loc_key(t):
    """Parse action turning the (loc_key, size) tokens into an ExprLoc."""
    assert len(t) == 2
    key_token = t[0]
    size = t[1]
    return ExprLoc(LocKey(key_token), size)

# Attach to each grammar element the callback that builds the matching
# Expr from its parsed tokens (registered in this exact order)
for _element, _action in (
        (expr_int, lambda t: ExprInt(*t)),
        (expr_id, lambda t: ExprId(*t)),
        (expr_loc, parse_loc_key),
        (expr_slice, lambda t: ExprSlice(*t)),
        (expr_mem, lambda t: ExprMem(*t)),
        (expr_cond, lambda t: ExprCond(*t)),
        (expr_compose, lambda t: ExprCompose(*t)),
        (expr_op, lambda t: ExprOp(*t)),
        (expr_aff, lambda t: ExprAssign(*t)),
):
    _element.setParseAction(_action)
# Keep the module namespace free of the loop temporaries
del _element, _action


def str_to_expr(str_in):
    """Parse @str_in and return the corresponding Expression.

    @str_in: repr string of an Expression
    Raises RuntimeError if @str_in cannot be parsed.
    """

    try:
        value = expr.parseString(str_in)
    except Exception:
        # Narrowed from a bare except so that KeyboardInterrupt/SystemExit
        # still propagate; any parse failure is reported uniformly.
        raise RuntimeError("Cannot parse expression %s" % str_in)
    assert len(value) == 1
    return value[0]
def simp_slice(e_s, expr):
    """Slice optimization.

    Try a series of rewrites of the ExprSlice @expr in order and return the
    result of the first one that applies; otherwise return @expr unchanged.
    @e_s: simplifier; its expr_simp_wrapper is used on sub-expressions
    @expr: ExprSlice to simplify
    Raises ValueError when a slice of a slice is wider than the inner slice.
    """

    # slice(A, 0, a.size) => A
    if expr.start == 0 and expr.stop == expr.arg.size:
        return expr.arg
    # Slice(int) => int
    if expr.arg.is_int():
        total_bit = expr.stop - expr.start
        mask = (1 << (expr.stop - expr.start)) - 1
        return ExprInt(int((expr.arg.arg >> expr.start) & mask), total_bit)
    # Slice(Slice(A, x), y) => Slice(A, z)
    if expr.arg.is_slice():
        # The outer slice cannot be wider than the inner one
        if expr.stop - expr.start > expr.arg.stop - expr.arg.start:
            raise ValueError('slice in slice: getting more val', str(expr))

        return ExprSlice(expr.arg.arg, expr.start + expr.arg.start,
                         expr.start + expr.arg.start + (expr.stop - expr.start))
    if expr.arg.is_compose():
        # Slice(Compose(A), x) => Slice(A, y)
        # (the slice lies entirely inside a single part)
        for index, arg in expr.arg.iter_args():
            if index <= expr.start and index+arg.size >= expr.stop:
                return arg[expr.start - index:expr.stop - index]
        # Slice(Compose(A, B, C), x) => Compose(A, B, C) with truncated A/B/C
        out = []
        for index, arg in expr.arg.iter_args():
            # arg is before slice start
            if expr.start >= index + arg.size:
                continue
            # arg is after slice stop
            elif expr.stop <= index:
                continue
            # arg is fully included in slice
            elif expr.start <= index and index + arg.size <= expr.stop:
                out.append(arg)
                continue
            # arg is truncated at start
            if expr.start > index:
                slice_start = expr.start - index
            else:
                # arg is not truncated at start
                slice_start = 0
            # a is truncated at stop
            if expr.stop < index + arg.size:
                # A part truncated at both ends was already handled by the
                # loop above, so slice_start is 0 here and this equals
                # expr.stop - index
                slice_stop = arg.size + expr.stop - (index + arg.size) - slice_start
            else:
                slice_stop = arg.size
            out.append(arg[slice_start:slice_stop])

        return ExprCompose(*out)

    # ExprMem(x, size)[:A] => ExprMem(x, a)
    # XXXX todo hum, is it safe?
    # NOTE(review): keeping the low bits of a read as a narrower read at the
    # same address assumes a little-endian layout. Confirm for all targets.
    if (expr.arg.is_mem() and
          expr.start == 0 and
          expr.arg.size > expr.stop and expr.stop % 8 == 0):
        return ExprMem(expr.arg.arg, size=expr.stop)
    # distributivity of slice and &
    # (a & int)[x:y] => 0 if int[x:y] == 0
    if expr.arg.is_op("&") and expr.arg.args[-1].is_int():
        tmp = e_s.expr_simp_wrapper(expr.arg.args[-1][expr.start:expr.stop])
        if tmp.is_int(0):
            return tmp
    # distributivity of slice and exprcond
    # (a?int1:int2)[x:y] => (a?int1[x:y]:int2[x:y])
    if expr.arg.is_cond() and expr.arg.src1.is_int() and expr.arg.src2.is_int():
        src1 = expr.arg.src1[expr.start:expr.stop]
        src2 = expr.arg.src2[expr.start:expr.stop]
        return ExprCond(expr.arg.cond, src1, src2)

    # (a * int)[0:y] => (a[0:y] * int[0:y])
    # (the low bits of a product depend only on the low bits of its operands)
    if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int():
        args = [e_s.expr_simp_wrapper(a[expr.start:expr.stop]) for a in expr.arg.args]
        return ExprOp(expr.arg.op, *args)

    # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size
    # (a << int)[x:y] => a[x-int:y-int] with x-int >= 0
    if (expr.arg.is_op() and expr.arg.op in [">>", "<<"] and
          expr.arg.args[1].is_int()):
        arg, shift = expr.arg.args
        shift = int(shift)
        if expr.arg.op == ">>":
            if shift + expr.stop <= arg.size:
                return arg[expr.start + shift:expr.stop + shift]
        elif expr.arg.op == "<<":
            if expr.start - shift >= 0:
                return arg[expr.start - shift:expr.stop - shift]
        else:
            raise ValueError('Bad case')

    return expr
Exemple #22
0
def extr(arg1, arg2, arg3, arg4):
    # Presumably EXTR semantics: take arg1.size bits of {arg2, arg3}
    # (arg2 in the low bits) starting at bit int(arg4).
    # NOTE(review): as plain Python, the result is only bound to the local
    # name arg1 and the function returns None. Presumably a decorator outside
    # this view (a semantic builder that turns assignments into IR) gives it
    # meaning, so confirm before restructuring. No docstring is added in case
    # such a rewriter inspects the body statements.
    compose = ExprCompose(arg2, arg3)
    # NOTE(review): int(arg4.arg) and int(arg4) are mixed below; both appear
    # to denote the same start offset. Confirm.
    arg1 = compose[int(arg4.arg):int(arg4) + arg1.size]
Exemple #23
0
    def test_ClassDef(self):
        from miasm2.expression.expression import ExprInt, ExprId, ExprMem, \
            ExprCompose, ExprAff
        from miasm2.arch.x86.sem import ir_x86_32
        from miasm2.ir.symbexec import SymbolicExecutionEngine
        from miasm2.ir.ir import AssignBlock


        id_x = ExprId('x', 32)
        id_a = ExprId('a', 32)
        id_b = ExprId('b', 32)
        id_c = ExprId('c', 32)
        id_d = ExprId('d', 32)
        id_e = ExprId('e', 64)

        sb = SymbolicExecutionEngine(ir_x86_32(),
                                    {
                                        ExprMem(ExprInt(0x4, 32), 8): ExprInt(0x44, 8),
                                        ExprMem(ExprInt(0x5, 32), 8): ExprInt(0x33, 8),
                                        ExprMem(ExprInt(0x6, 32), 8): ExprInt(0x22, 8),
                                        ExprMem(ExprInt(0x7, 32), 8): ExprInt(0x11, 8),

                                        ExprMem(ExprInt(0x20, 32), 32): id_x,

                                        ExprMem(ExprInt(0x40, 32), 32): id_x,
                                        ExprMem(ExprInt(0x44, 32), 32): id_a,

                                        ExprMem(ExprInt(0x54, 32), 32): ExprInt(0x11223344, 32),

                                        ExprMem(id_a, 32): ExprInt(0x11223344, 32),
                                        id_a: ExprInt(0, 32),
                                        id_b: ExprInt(0, 32),

                                        ExprMem(id_c, 32): ExprMem(id_d + ExprInt(0x4, 32), 32),
                                        ExprMem(id_c + ExprInt(0x4, 32), 32): ExprMem(id_d + ExprInt(0x8, 32), 32),

                                    })


        self.assertEqual(sb.eval_expr(ExprInt(1, 32)-ExprInt(1, 32)), ExprInt(0, 32))

        ## Test with unknown mem + integer
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0, 32), 32)), ExprMem(ExprInt(0, 32), 32))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(1, 32), 32)), ExprCompose(ExprMem(ExprInt(1, 32), 24), ExprInt(0x44, 8)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(2, 32), 32)), ExprCompose(ExprMem(ExprInt(2, 32), 16), ExprInt(0x3344, 16)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(3, 32), 32)), ExprCompose(ExprMem(ExprInt(3, 32), 8), ExprInt(0x223344, 24)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(4, 32), 32)), ExprInt(0x11223344, 32))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(5, 32), 32)), ExprCompose(ExprInt(0x112233, 24), ExprMem(ExprInt(8, 32), 8)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(6, 32), 32)), ExprCompose(ExprInt(0x1122, 16), ExprMem(ExprInt(8, 32), 16)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(7, 32), 32)), ExprCompose(ExprInt(0x11, 8), ExprMem(ExprInt(8, 32), 24)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(8, 32), 32)), ExprMem(ExprInt(8, 32), 32))

        ## Test with unknown mem + integer
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x50, 32), 32)), ExprMem(ExprInt(0x50, 32), 32))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x51, 32), 32)), ExprCompose(ExprMem(ExprInt(0x51, 32), 24), ExprInt(0x44, 8)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x52, 32), 32)), ExprCompose(ExprMem(ExprInt(0x52, 32), 16), ExprInt(0x3344, 16)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x53, 32), 32)), ExprCompose(ExprMem(ExprInt(0x53, 32), 8), ExprInt(0x223344, 24)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x54, 32), 32)), ExprInt(0x11223344, 32))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x55, 32), 32)), ExprCompose(ExprInt(0x112233, 24), ExprMem(ExprInt(0x58, 32), 8)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x56, 32), 32)), ExprCompose(ExprInt(0x1122, 16), ExprMem(ExprInt(0x58, 32), 16)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x57, 32), 32)), ExprCompose(ExprInt(0x11, 8), ExprMem(ExprInt(0x58, 32), 24)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x58, 32), 32)), ExprMem(ExprInt(0x58, 32), 32))



        ## Test with unknown mem + id
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x1D, 32), 32)), ExprCompose(ExprMem(ExprInt(0x1D, 32), 24), id_x[:8]))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x1E, 32), 32)), ExprCompose(ExprMem(ExprInt(0x1E, 32), 16), id_x[:16]))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x1F, 32), 32)), ExprCompose(ExprMem(ExprInt(0x1F, 32), 8), id_x[:24]))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x20, 32), 32)), id_x)
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x21, 32), 32)), ExprCompose(id_x[8:], ExprMem(ExprInt(0x24, 32), 8)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x22, 32), 32)), ExprCompose(id_x[16:], ExprMem(ExprInt(0x24, 32), 16)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x23, 32), 32)), ExprCompose(id_x[24:], ExprMem(ExprInt(0x24, 32), 24)))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x24, 32), 32)), ExprMem(ExprInt(0x24, 32), 32))


        ## Partial read
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(4, 32), 8)), ExprInt(0x44, 8))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x20, 32), 8)), id_x[:8])
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x23, 32), 8)), id_x[24:])


        ## Merge
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x40, 32), 64)), ExprCompose(id_x, id_a))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x42, 32), 32)), ExprCompose(id_x[16:], id_a[:16]))

        # Merge memory
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x100, 32), 32)), ExprMem(ExprInt(0x100, 32), 32))
        self.assertEqual(sb.eval_expr(ExprMem(id_c + ExprInt(0x2, 32), 32)), ExprMem(id_d  + ExprInt(0x6, 32), 32))

        ## Func read
        def custom_func_read(mem):
            if mem == ExprMem(ExprInt(0x1000, 32), 32):
                return id_x
            return mem

        sb.func_read = custom_func_read

        ## Unmodified read
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(4, 32), 8)), ExprInt(0x44, 8))

        ## Modified read
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x1000, 32), 32)), id_x)


        ## Apply_change / eval_ir / apply_expr

        ## x = a (with a = 0x0)
        assignblk = AssignBlock({id_x:id_a})
        sb.eval_updt_assignblk(assignblk)
        self.assertEqual(sb.eval_expr(id_x), ExprInt(0, 32))

        ## x = a (without replacing 'a' with 0x0)
        sb.apply_change(id_x, id_a)
        self.assertEqual(sb.eval_expr(id_x), id_a)

        ## x = a (with a = 0x0)
        self.assertEqual(sb.eval_updt_expr(assignblk.dst2ExprAff(id_x)), ExprInt(0, 32))
        self.assertEqual(sb.eval_expr(id_x), ExprInt(0, 32))
        self.assertEqual(sb.eval_updt_expr(id_x), ExprInt(0, 32))

        sb.dump()

        ## state
        reads = set()
        for dst, src in sb.modified():
            reads.update(ExprAff(dst, src).get_r())

        self.assertEqual(reads, set([
            id_x, id_a,
            ExprMem(id_d + ExprInt(0x4, 32), 32),
            ExprMem(id_d + ExprInt(0x8, 32), 32),
        ]))

        # Erase low id_x byte with 0xFF
        sb.apply_change(ExprMem(ExprInt(0x20, 32), 8), ExprInt(0xFF, 8))
        state = dict(sb.modified(ids=False))
        self.assertEqual(state[ExprMem(ExprInt(0x20, 32), 8)], ExprInt(0xFF, 8))
        self.assertEqual(state[ExprMem(ExprInt(0x21, 32), 24)], id_x[8:32])

        # Erase high id_x byte with 0xEE
        sb.apply_change(ExprMem(ExprInt(0x23, 32), 8), ExprInt(0xEE, 8))

        state = dict(sb.modified(ids=False))
        self.assertEqual(state[ExprMem(ExprInt(0x20, 32), 8)], ExprInt(0xFF, 8))
        self.assertEqual(state[ExprMem(ExprInt(0x21, 32), 16)], id_x[8:24])
        self.assertEqual(state[ExprMem(ExprInt(0x23, 32), 8)], ExprInt(0xEE, 8))

        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x22, 32), 32)), ExprCompose(id_x[16:24], ExprInt(0xEE, 8), ExprMem(ExprInt(0x24, 32), 16)))

        # Erase low byte of 0x11223344 with 0xFF at 0x54
        sb.apply_change(ExprMem(ExprInt(0x54, 32), 8), ExprInt(0xFF, 8))

        # Erase low byte of 0x11223344 with 0xFF at id_a
        sb.apply_change(ExprMem(id_a + ExprInt(0x1, 32), 8), ExprInt(0xFF, 8))
        state = dict(sb.modified(ids=False))
        self.assertEqual(state[ExprMem(id_a + ExprInt(0x1, 32), 8)], ExprInt(0xFF, 8))
        self.assertEqual(state[ExprMem(id_a + ExprInt(0x2, 32), 16)], ExprInt(0x1122, 16))

        # Write uint32_t at 0xFFFFFFFE
        sb.apply_change(ExprMem(ExprInt(0xFFFFFFFE, 32), 32), ExprInt(0x11223344, 32))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0, 32), 16)), ExprInt(0x1122, 16))

        # Revert memory to original value at 0x42
        sb.apply_change(ExprMem(ExprInt(0x42, 32), 32), ExprMem(ExprInt(0x42, 32), 32))
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0x42, 32), 32)), ExprMem(ExprInt(0x42, 32), 32))

        # Revert memory to original value at c + 0x2
        sb.apply_change(ExprMem(id_c + ExprInt(0x2, 32), 32), ExprMem(id_c + ExprInt(0x2, 32), 32))
        self.assertEqual(sb.eval_expr(ExprMem(id_c + ExprInt(0x2, 32), 32)), ExprMem(id_c + ExprInt(0x2, 32), 32))

        # Test del symbol
        del sb.symbols[id_a]
        sb.dump()
        del sb.symbols[ExprMem(id_a, 8)]
        print "*"*40, 'Orig:'
        sb.dump()

        sb_cp = sb.symbols.copy()
        print "*"*40, 'Copy:'
        sb_cp.dump()

        # Add symbol at address limit
        sb.apply_change(ExprMem(ExprInt(0xFFFFFFFE, 32), 32), id_c)
        sb.dump()
        found = False
        for dst, src in sb.symbols.iteritems():
            if dst == ExprMem(ExprInt(0xFFFFFFFE, 32), 32) and src == id_c:
                found = True
        assert found


        # Add symbol at address limit
        sb.apply_change(ExprMem(ExprInt(0x7FFFFFFE, 32), 32), id_c)
        sb.dump()
        found = False
        for dst, src in sb.symbols.iteritems():
            if dst == ExprMem(ExprInt(0x7FFFFFFE, 32), 32) and src == id_c:
                found = True
        assert found



        # Add truncated symbol at address limit
        sb.apply_change(ExprMem(ExprInt(0xFFFFFFFC, 32), 64), id_e)
        # Revert parts of memory
        sb.apply_change(ExprMem(ExprInt(0xFFFFFFFC, 32), 16), ExprMem(ExprInt(0xFFFFFFFC, 32), 16))
        sb.apply_change(ExprMem(ExprInt(0x2, 32), 16), ExprMem(ExprInt(0x2, 32), 16))
        sb.dump()
        found = False
        for dst, src in sb.symbols.iteritems():
            if dst == ExprMem(ExprInt(0xFFFFFFFE, 32), 32) and src == id_e[16:48]:
                found = True
        assert found


        sb_empty = SymbolicExecutionEngine(ir_x86_32(), {})
        sb_empty.dump()


        # Test memory full
        print 'full'
        arch_addr8 = ir_x86_32()
        # Hack to obtain tiny address space
        arch_addr8.addrsize = 5
        sb_addr8 = SymbolicExecutionEngine(arch_addr8, {})
        sb_addr8.dump()
        # Fulfill memory
        sb_addr8.apply_change(ExprMem(ExprInt(0, 5), 256), ExprInt(0, 256))
        sb_addr8.dump()
        variables = sb_addr8.symbols.items()
        assert variables == [(ExprMem(ExprInt(0, 5), 256), ExprInt(0, 256))]

        print sb_addr8.symbols.symbols_mem

        sb_addr8.apply_change(ExprMem(ExprInt(0x5, 5), 256), ExprInt(0x123, 256))
        sb_addr8.dump()
        variables = sb_addr8.symbols.items()
        assert variables == [(ExprMem(ExprInt(0x5, 5), 256), ExprInt(0x123, 256))]
        print sb_addr8.symbols.symbols_mem

        print 'dump'
        sb_addr8.symbols.symbols_mem.dump()


        sb.dump()
        try:
            del sb.symbols.symbols_mem[ExprMem(ExprInt(0xFFFFFFFF, 32), 32)]
        except KeyError:
            # ok
            pass
        else:
            raise RuntimeError("Should raise error!")


        del sb.symbols.symbols_mem[ExprMem(ExprInt(0xFFFFFFFF, 32), 16)]
        sb.dump()
        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0xFFFFFFFE, 32), 32)),
                         ExprCompose(id_e[16:24], ExprMem(ExprInt(0xFFFFFFFF, 32), 16), id_e[40:48]))
        sb.symbols.symbols_mem.delete_partial(ExprMem(ExprInt(0xFFFFFFFF, 32), 32))

        self.assertEqual(sb.eval_expr(ExprMem(ExprInt(0xFFFFFFFE, 32), 32)),
                         ExprCompose(id_e[16:24], ExprMem(ExprInt(0xFFFFFFFF, 32), 24)))

        sb.dump()

        assert ExprMem(ExprInt(0xFFFFFFFE, 32), 8) in sb.symbols
        assert ExprMem(ExprInt(0xFFFFFFFE, 32), 32) not in sb.symbols
        assert sb.symbols.symbols_mem.contains_partial(ExprMem(ExprInt(0xFFFFFFFE, 32), 32))
        assert not sb.symbols.symbols_mem.contains_partial(ExprMem(ExprInt(0xFFFFFFFF, 32), 8))

        assert sb_addr8.symbols.keys() == [ExprMem(ExprInt(0x5, 5), 256)]
Exemple #24
0
from miasm2.expression.parser import str_to_expr
from miasm2.expression.expression import ExprInt, ExprId, ExprSlice, ExprMem, \
    ExprCond, ExprCompose, ExprOp, ExprAff, ExprLoc, LocKey

# Round-trip check: the repr() of every sample expression must be parsed back
# by str_to_expr into an expression equal to the original one.
for expr_test in (
        ExprInt(0x12, 32),
        ExprId('test', 32),
        ExprLoc(LocKey(12), 32),
        ExprSlice(ExprInt(0x10, 32), 0, 8),
        ExprMem(ExprInt(0x10, 32), 32),
        ExprCond(ExprInt(0x10, 32), ExprInt(0x11, 32), ExprInt(0x12, 32)),
        ExprCompose(ExprInt(0x10, 16), ExprInt(0x11, 8), ExprInt(0x12, 8)),
        ExprInt(0x11, 8) + ExprInt(0x12, 8),
        ExprAff(ExprId('EAX', 32), ExprInt(0x12, 32)),
):
    print('Test: %s' % (expr_test,))
    assert str_to_expr(repr(expr_test)) == expr_test
def simp_cst_propagation(e_s, expr):
    """Simplify an ExprOp node by constant folding and algebraic identities.

    This pass includes:
     - Constant folding
     - Common logical identities
     - Common binary identities

    @e_s: expression simplifier, called on the helper expressions built to
          check rewrite preconditions (presumably the ExpressionSimplifier
          instance driving this pass - confirm against the caller)
    @expr: the ExprOp expression to simplify

    Returns an equivalent, simplified expression; when no rule returns early,
    an ExprOp rebuilt from the (possibly folded) operands.

    Raises ValueError for a '-' operation with more than two arguments, and
    AssertionError when folding a division/modulo by a zero constant.
    """

    # Work on a mutable copy of the operands
    args = list(expr.args)
    op_name = expr.op
    # simpl integer manip
    # int OP int => int
    # TODO: <<< >>> << >> are architecture dependant
    if op_name in op_propag_cst:
        while (len(args) >= 2 and args[-1].is_int() and args[-2].is_int()):
            int2 = args.pop()
            int1 = args.pop()
            if op_name == '+':
                out = int1.arg + int2.arg
            elif op_name == '*':
                out = int1.arg * int2.arg
            elif op_name == '**':
                out = int1.arg**int2.arg
            elif op_name == '^':
                out = int1.arg ^ int2.arg
            elif op_name == '&':
                out = int1.arg & int2.arg
            elif op_name == '|':
                out = int1.arg | int2.arg
            elif op_name == '>>':
                # Shifting by more than the operand size always yields 0
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg >> int2.arg
            elif op_name == '<<':
                # Shifting by more than the operand size always yields 0;
                # short-circuit so a huge constant shift amount cannot make
                # Python build an enormous intermediate integer.
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg << int2.arg
            elif op_name == 'a>>':
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 >> tmp2)
            elif op_name == '>>>':
                out = (int1.arg >> (int2.arg % int2.size) | int1.arg <<
                       ((int1.size - int2.arg) % int2.size))
            elif op_name == '<<<':
                out = (int1.arg << (int2.arg % int2.size) | int1.arg >>
                       ((int1.size - int2.arg) % int2.size))
            elif op_name == '/':
                out = int1.arg / int2.arg
            elif op_name == '%':
                out = int1.arg % int2.arg
            elif op_name == 'idiv':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)
            elif op_name == 'imod':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'umod':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'udiv':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)

            args.append(ExprInt(out, int1.size))

    # bsf(int) => int
    if op_name == "bsf" and args[0].is_int() and args[0].arg != 0:
        i = 0
        while args[0].arg & (1 << i) == 0:
            i += 1
        return ExprInt(i, args[0].size)

    # bsr(int) => int
    if op_name == "bsr" and args[0].is_int() and args[0].arg != 0:
        i = args[0].size - 1
        while args[0].arg & (1 << i) == 0:
            i -= 1
        return ExprInt(i, args[0].size)

    # -(-(A)) => A
    if (op_name == '-' and len(args) == 1 and args[0].is_op('-')
            and len(args[0].args) == 1):
        return args[0].args[0]

    # -(int) => -int
    if op_name == '-' and len(args) == 1 and args[0].is_int():
        return ExprInt(-int(args[0]), expr.size)
    # A op 0 =>A
    if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
        if args[-1].is_int(0):
            args.pop()
    # A - 0 =>A
    if op_name == '-' and len(args) > 1 and args[-1].is_int(0):
        assert len(
            args) == 2  # Op '-' with more than 2 args: SantityCheckError
        return args[0]

    # A * 1 =>A
    if op_name == "*" and len(args) > 1 and args[-1].is_int(1):
        args.pop()

    # for cannon form
    # A * -1 => - A
    if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask:
        args.pop()
        args[-1] = -args[-1]

    # op A => A
    if op_name in [
            '+', '*', '^', '&', '|', '>>', '<<', 'a>>', '<<<', '>>>', 'idiv',
            'imod', 'umod', 'udiv'
    ] and len(args) == 1:
        return args[0]

    # A-B => A + (-B)
    if op_name == '-' and len(args) > 1:
        if len(args) > 2:
            raise ValueError(
                'sanity check fail on expr -: should have one or 2 args ' +
                '%r %s' % (expr, expr))
        return ExprOp('+', args[0], -args[1])

    # A op 0 => 0
    # Constants are folded at the end of the operand list, so test the last
    # operand (not the second one, which misses e.g. 'a & b & 0').
    if op_name in ['&', "*"] and args[-1].is_int(0):
        return ExprInt(0, expr.size)

    # - (A + B +...) => -A + -B + -C
    if op_name == '-' and len(args) == 1 and args[0].is_op('+'):
        args = [-a for a in args[0].args]
        return ExprOp('+', *args)

    # -(a?int1:int2) => (a?-int1:-int2)
    if (op_name == '-' and len(args) == 1 and args[0].is_cond()
            and args[0].src1.is_int() and args[0].src2.is_int()):
        int1 = args[0].src1
        int2 = args[0].src2
        int1 = ExprInt(-int1.arg, int1.size)
        int2 = ExprInt(-int2.arg, int2.size)
        return ExprCond(args[0].cond, int1, int2)

    i = 0
    while i < len(args) - 1:
        j = i + 1
        while j < len(args):
            # A ^ A => 0
            if op_name == '^' and args[i] == args[j]:
                args[i] = ExprInt(0, args[i].size)
                del args[j]
                continue
            # A + (- A) => 0
            if op_name == '+' and args[j].is_op("-"):
                if len(args[j].args) == 1 and args[i] == args[j].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # (- A) + A => 0
            if op_name == '+' and args[i].is_op("-"):
                if len(args[i].args) == 1 and args[j] == args[i].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # A | A => A
            if op_name == '|' and args[i] == args[j]:
                del args[j]
                continue
            # A & A => A
            if op_name == '&' and args[i] == args[j]:
                del args[j]
                continue
            j += 1
        i += 1

    if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1:
        return args[0]

    # A <<< A.size => A
    if (op_name in ['<<<', '>>>'] and args[1].is_int()
            and args[1].arg == args[0].size):
        return args[0]

    # A <<< X <<< Y => A <<< (X+Y) (or <<< >>>)
    # Rotation amounts are reduced modulo the operand size while X+Y / X-Y
    # wrap modulo 2**size; the two only agree when the size is a power of
    # two, so the merge is restricted to that case.
    if (op_name in ['<<<', '>>>'] and args[0].is_op()
            and args[0].op in ['<<<', '>>>']
            and (expr.size & (expr.size - 1)) == 0):
        op1 = op_name
        op2 = args[0].op
        if op1 == op2:
            op_name = op1
            args1 = args[0].args[1] + args[1]
        else:
            op_name = op2
            args1 = args[0].args[1] - args[1]

        args0 = args[0].args[0]
        args = [args0, args1]

    # A >> X >> Y  =>  A >> (X+Y)
    # Only when both X and Y have a null most significant bit: X+Y then
    # cannot wrap around (which would turn a large shift into a small one).
    if op_name in ['<<', '>>'] and args[0].is_op(op_name):
        shift_x = args[0].args[1]
        shift_y = args[1]
        if (e_s(shift_x[shift_x.size - 1:shift_x.size]) == ExprInt(0, 1) and
                e_s(shift_y[shift_y.size - 1:shift_y.size]) == ExprInt(0, 1)):
            args = [args[0].args[0], shift_x + shift_y]

    # ((A & A.mask)
    if op_name == "&" and args[-1] == expr.mask:
        return ExprOp('&', *args[:-1])

    # ((A | A.mask)
    if op_name == "|" and args[-1] == expr.mask:
        return args[-1]

    # ! (!X + int) => X - int
    # TODO

    # ((A & mask) >> shift) whith mask < 2**shift => 0
    if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
        if (args[0].args[1].is_int() and 2**args[1].arg > args[0].args[1].arg):
            return ExprInt(0, args[0].size)

    # parity(int) => int
    if op_name == 'parity' and args[0].is_int():
        return ExprInt(parity(int(args[0])), 1)

    # (-a) * b * (-c) * (-d) => (-a) * b * c * d
    if op_name == "*" and len(args) > 1:
        new_args = []
        counter = 0
        for arg in args:
            if arg.is_op('-') and len(arg.args) == 1:
                new_args.append(arg.args[0])
                counter += 1
            else:
                new_args.append(arg)
        if counter % 2:
            return -ExprOp(op_name, *new_args)
        args = new_args

    # A << int with A ExprCompose => move index
    if (op_name == "<<" and args[0].is_compose() and args[1].is_int()
            and int(args[1]) != 0):
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index + shift, index + shift + arg.size))
        # filter out expression
        filter_args = []
        min_index = final_size
        for tmp, start, stop in new_args:
            if start >= final_size:
                continue
            if stop > final_size:
                # Truncate the part shifted beyond the top of the result
                tmp = tmp[:tmp.size - (stop - final_size)]
            filter_args.append(tmp)
            min_index = min(start, min_index)
        # create entry 0
        assert min_index != 0
        tmp = ExprInt(0, min_index)
        args = [tmp] + filter_args
        return ExprCompose(*args)

    # A >> int with A ExprCompose => move index
    if op_name == ">>" and args[0].is_compose() and args[1].is_int():
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index - shift, index + arg.size - shift))
        # filter out expression
        filter_args = []
        max_index = 0
        for tmp, start, stop in new_args:
            if stop <= 0:
                continue
            if start < 0:
                # Drop the part shifted below bit 0 of the result
                tmp = tmp[-start:]
            filter_args.append(tmp)
            max_index = max(stop, max_index)
        # create entry 0
        tmp = ExprInt(0, final_size - max_index)
        args = filter_args + [tmp]
        return ExprCompose(*args)

    # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b)
    if op_name in ['|', '&', '^'] and all([arg.is_compose() for arg in args]):
        bounds = set()
        for arg in args:
            bound = tuple([tmp.size for tmp in arg.args])
            bounds.add(bound)
        if len(bounds) == 1:
            new_args = [[tmp] for tmp in args[0].args]
            for sub_arg in args[1:]:
                for i, tmp in enumerate(sub_arg.args):
                    new_args[i].append(tmp)
            args = [ExprOp(op_name, *arg) for arg in new_args]
            return ExprCompose(*args)

    # <<<c_rez, >>>c_rez
    if op_name in [">>>c_rez", "<<<c_rez"]:
        assert len(args) == 3
        dest, rounds, carry_flag = args
        # Skipped if rounds is 0
        if rounds.is_int(0):
            return dest
        elif all(arg.is_int() for arg in args):
            # The expression can be resolved
            tmp = int(dest)
            carry_flag = int(carry_flag)
            size = dest.size
            tmp_count = (int(rounds) &
                         (0x3f if size == 64 else 0x1f)) % (size + 1)
            if op_name == ">>>c_rez":
                while tmp_count != 0:
                    tmp_cf = tmp & 1
                    tmp = (tmp >> 1) + (carry_flag << (size - 1))
                    carry_flag = tmp_cf
                    tmp_count -= 1
                    tmp &= int(dest.mask)
            elif op_name == "<<<c_rez":
                while tmp_count != 0:
                    tmp_cf = (tmp >> (size - 1)) & 1
                    tmp = (tmp << 1) + carry_flag
                    carry_flag = tmp_cf
                    tmp_count -= 1
                    tmp &= int(dest.mask)
            else:
                raise RuntimeError("Unknown operation: %s" % op_name)
            return ExprInt(tmp, size=dest.size)

    return ExprOp(op_name, *args)
def simp_cst_propagation(e_s, expr):
    """Simplify an ExprOp node by constant folding and algebraic identities.

    This pass includes:
     - Constant folding
     - Common logical identities
     - Common binary identities

    @e_s: expression simplifier, called on the helper expressions built to
          check rewrite preconditions (presumably the ExpressionSimplifier
          instance driving this pass - confirm against the caller)
    @expr: the ExprOp expression to simplify

    Returns an equivalent, simplified expression; when no rule returns early,
    an ExprOp rebuilt from the (possibly folded) operands.

    Raises ValueError for a '-' operation with more than two arguments, and
    AssertionError when folding a division/modulo by a zero constant.
    """

    # Work on a mutable copy of the operands
    args = list(expr.args)
    op_name = expr.op
    # simpl integer manip
    # int OP int => int
    # TODO: <<< >>> << >> are architecture dependant
    if op_name in op_propag_cst:
        while (len(args) >= 2 and
            args[-1].is_int() and
            args[-2].is_int()):
            int2 = args.pop()
            int1 = args.pop()
            if op_name == '+':
                out = int1.arg + int2.arg
            elif op_name == '*':
                out = int1.arg * int2.arg
            elif op_name == '**':
                out =int1.arg ** int2.arg
            elif op_name == '^':
                out = int1.arg ^ int2.arg
            elif op_name == '&':
                out = int1.arg & int2.arg
            elif op_name == '|':
                out = int1.arg | int2.arg
            elif op_name == '>>':
                # Shifting by more than the operand size always yields 0
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg >> int2.arg
            elif op_name == '<<':
                # Also avoids building a huge intermediate Python integer
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg << int2.arg
            elif op_name == 'a>>':
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                if tmp2 > int1.size:
                    # Over-shift: result is filled with the sign bit
                    is_signed = int(int1) & (1 << (int1.size - 1))
                    if is_signed:
                        out = -1
                    else:
                        out = 0
                else:
                    out = mod_size2uint[int1.arg.size](tmp1 >> tmp2)
            elif op_name == '>>>':
                shifter = int2.arg % int2.size
                out = (int1.arg >> shifter) | (int1.arg << (int2.size - shifter))
            elif op_name == '<<<':
                shifter = int2.arg % int2.size
                out = (int1.arg << shifter) | (int1.arg >> (int2.size - shifter))
            elif op_name == '/':
                out = int1.arg / int2.arg
            elif op_name == '%':
                out = int1.arg % int2.arg
            elif op_name == 'idiv':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)
            elif op_name == 'imod':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'umod':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'udiv':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)

            args.append(ExprInt(out, int1.size))

    # cnttrailzeros(int) => int
    if op_name == "cnttrailzeros" and args[0].is_int():
        i = 0
        while args[0].arg & (1 << i) == 0 and i < args[0].size:
            i += 1
        return ExprInt(i, args[0].size)

    # cntleadzeros(int) => int
    if op_name == "cntleadzeros" and args[0].is_int():
        if args[0].arg == 0:
            return ExprInt(args[0].size, args[0].size)
        i = args[0].size - 1
        while args[0].arg & (1 << i) == 0:
            i -= 1
        # Leading zeros are counted relative to the operand width
        return ExprInt(args[0].size - (i + 1), args[0].size)

    # -(-(A)) => A
    if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and
        len(args[0].args) == 1):
        return args[0].args[0]

    # -(int) => -int
    if op_name == '-' and len(args) == 1 and args[0].is_int():
        return ExprInt(-int(args[0]), expr.size)
    # A op 0 =>A
    if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
        if args[-1].is_int(0):
            args.pop()
    # A - 0 =>A
    if op_name == '-' and len(args) > 1 and args[-1].is_int(0):
        assert len(args) == 2 # Op '-' with more than 2 args: SantityCheckError
        return args[0]

    # A * 1 =>A
    if op_name == "*" and len(args) > 1 and args[-1].is_int(1):
        args.pop()

    # for cannon form
    # A * -1 => - A
    if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask:
        args.pop()
        args[-1] = - args[-1]

    # op A => A
    if op_name in ['+', '*', '^', '&', '|', '>>', '<<',
              'a>>', '<<<', '>>>', 'idiv', 'imod', 'umod', 'udiv'] and len(args) == 1:
        return args[0]

    # A-B => A + (-B)
    if op_name == '-' and len(args) > 1:
        if len(args) > 2:
            raise ValueError(
                'sanity check fail on expr -: should have one or 2 args ' +
                '%r %s' % (expr, expr))
        return ExprOp('+', args[0], -args[1])

    # A op 0 => 0
    if op_name in ['&', "*"] and args[-1].is_int(0):
        return ExprInt(0, expr.size)

    # - (A + B +...) => -A + -B + -C
    if op_name == '-' and len(args) == 1 and args[0].is_op('+'):
        args = [-a for a in args[0].args]
        return ExprOp('+', *args)

    # -(a?int1:int2) => (a?-int1:-int2)
    if (op_name == '-' and len(args) == 1 and
        args[0].is_cond() and
        args[0].src1.is_int() and args[0].src2.is_int()):
        int1 = args[0].src1
        int2 = args[0].src2
        int1 = ExprInt(-int1.arg, int1.size)
        int2 = ExprInt(-int2.arg, int2.size)
        return ExprCond(args[0].cond, int1, int2)

    i = 0
    while i < len(args) - 1:
        j = i + 1
        while j < len(args):
            # A ^ A => 0
            if op_name == '^' and args[i] == args[j]:
                args[i] = ExprInt(0, args[i].size)
                del args[j]
                continue
            # A + (- A) => 0
            if op_name == '+' and args[j].is_op("-"):
                if len(args[j].args) == 1 and args[i] == args[j].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # (- A) + A => 0
            if op_name == '+' and args[i].is_op("-"):
                if len(args[i].args) == 1 and args[j] == args[i].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # A | A => A
            if op_name == '|' and args[i] == args[j]:
                del args[j]
                continue
            # A & A => A
            if op_name == '&' and args[i] == args[j]:
                del args[j]
                continue
            j += 1
        i += 1

    if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1:
        return args[0]

    # A <<< A.size => A
    if (op_name in ['<<<', '>>>'] and
        args[1].is_int() and
        args[1].arg == args[0].size):
        return args[0]

    # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
    if (op_name in ['<<<', '>>>'] and
        args[0].is_op() and
        args[0].op in ['<<<', '>>>']):
        A = args[0].args[0]
        X = args[0].args[1]
        Y = args[1]
        if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size):
            # Opposite rotations by the same amount cancel out
            return args[0].args[0]
        elif X.is_int() and Y.is_int():
            new_X = int(X) % expr.size
            new_Y = int(Y) % expr.size
            if op_name == args[0].op:
                rot = (new_X + new_Y) % expr.size
                op = op_name
            else:
                rot = new_Y - new_X
                op = op_name
                if rot < 0:
                    rot = - rot
                    op = {">>>": "<<<", "<<<": ">>>"}[op_name]
            args = [A, ExprInt(rot, expr.size)]
            op_name = op

        else:
            # Do not consider this case, too tricky (overflow on addition /
            # substraction)
            pass

    # A >> X >> Y  =>  A >> (X+Y) if X + Y does not overflow
    # To be sure, only consider the simplification when X.msb and Y.msb are 0
    if (op_name in ['<<', '>>'] and
        args[0].is_op(op_name)):
        X = args[0].args[1]
        Y = args[1]
        if (e_s(X.msb()) == ExprInt(0, 1) and
            e_s(Y.msb()) == ExprInt(0, 1)):
            args = [args[0].args[0], X + Y]

    # ((var >> int1) << int1) => var & mask
    # ((var << int1) >> int1) => var & mask
    # Only valid for opposite shifts: compare with the inner *operator*
    # (comparing the string with the inner expression is always True and
    # would wrongly rewrite ((var >> k) >> k) as well).
    if (op_name in ['<<', '>>'] and
        args[0].is_op() and
        args[0].op in ['<<', '>>'] and
        op_name != args[0].op):
        var = args[0].args[0]
        int1 = args[0].args[1]
        int2 = args[1]
        if int1 == int2 and int1.is_int() and int(int1) < expr.size:
            if op_name == '>>':
                # High int1 bits are cleared
                mask = ExprInt((1 << (expr.size - int(int1))) - 1, expr.size)
            else:
                # Low int1 bits are cleared
                mask = ExprInt(
                    ((1 << int(int1)) - 1) ^ ((1 << expr.size) - 1),
                    expr.size
                )
            ret = var & mask
            return ret

    # ((A & A.mask)
    if op_name == "&" and args[-1] == expr.mask:
        return ExprOp('&', *args[:-1])

    # ((A | A.mask)
    if op_name == "|" and args[-1] == expr.mask:
        return args[-1]

    # ! (!X + int) => X - int
    # TODO

    # ((A & mask) >> shift) whith mask < 2**shift => 0
    if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
        if (args[0].args[1].is_int() and
            2 ** args[1].arg > args[0].args[1].arg):
            return ExprInt(0, args[0].size)

    # parity(int) => int
    if op_name == 'parity' and args[0].is_int():
        return ExprInt(parity(int(args[0])), 1)

    # (-a) * b * (-c) * (-d) => (-a) * b * c * d
    if op_name == "*" and len(args) > 1:
        new_args = []
        counter = 0
        for arg in args:
            if arg.is_op('-') and len(arg.args) == 1:
                new_args.append(arg.args[0])
                counter += 1
            else:
                new_args.append(arg)
        if counter % 2:
            return -ExprOp(op_name, *new_args)
        args = new_args

    # A << int with A ExprCompose => move index
    if (op_name == "<<" and args[0].is_compose() and
        args[1].is_int() and int(args[1]) != 0):
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index+shift, index+shift+arg.size))
        # filter out expression
        filter_args = []
        min_index = final_size
        for tmp, start, stop in new_args:
            if start >= final_size:
                continue
            if stop > final_size:
                # Truncate the part shifted beyond the top of the result
                tmp = tmp[:tmp.size  - (stop - final_size)]
            filter_args.append(tmp)
            min_index = min(start, min_index)
        # create entry 0
        assert min_index != 0
        tmp = ExprInt(0, min_index)
        args = [tmp] + filter_args
        return ExprCompose(*args)

    # A >> int with A ExprCompose => move index
    if op_name == ">>" and args[0].is_compose() and args[1].is_int():
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index-shift, index+arg.size-shift))
        # filter out expression
        filter_args = []
        max_index = 0
        for tmp, start, stop in new_args:
            if stop <= 0:
                continue
            if start < 0:
                # Drop the part shifted below bit 0 of the result
                tmp = tmp[-start:]
            filter_args.append(tmp)
            max_index = max(stop, max_index)
        # create entry 0
        tmp = ExprInt(0, final_size - max_index)
        args = filter_args + [tmp]
        return ExprCompose(*args)


    # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b)
    if op_name in ['|', '&', '^'] and all([arg.is_compose() for arg in args]):
        bounds = set()
        for arg in args:
            bound = tuple([tmp.size for tmp in arg.args])
            bounds.add(bound)
        if len(bounds) == 1:
            new_args = [[tmp] for tmp in args[0].args]
            for sub_arg in args[1:]:
                for i, tmp in enumerate(sub_arg.args):
                    new_args[i].append(tmp)
            args = [ExprOp(op_name, *arg) for arg in new_args]
            return ExprCompose(*args)

    return ExprOp(op_name, *args)
Exemple #27
0
    def apply_expr_on_state_visit_cache(self, expr, state, cache, level=0):
        """
        Evaluate @expr against @state, depth first, memoising in @cache:
            1. the sons of @expr are evaluated recursively
            2. the rebuilt node is simplified with self.expr_simp

        An expression found in @state (and not in @cache) is returned as is:
        it is neither simplified nor cached. If an evaluated son of a
        cond/slice/op/compose node is the TOP identifier, the node evaluates
        to TOP (see exprid_top).

        @expr: expression to evaluate
        @state: mapping from expressions to their current values
        @cache: dictionary of already evaluated expressions, updated in place
        @level: recursion depth; forwarded to sons (+1) and to manage_mem
        Raises TypeError on an unsupported expression class.
        """
        def visit(son):
            # Evaluate a sub-expression one level deeper, sharing state/cache
            return self.apply_expr_on_state_visit_cache(son, state, cache,
                                                        level + 1)

        if expr in cache:
            ret = cache[expr]
        elif expr in state:
            return state[expr]
        elif expr.is_int():
            ret = expr
        elif expr.is_id():
            name = expr.name
            if (isinstance(name, asmblock.asm_label) and
                    name.offset is not None):
                # Labels with a known offset resolve to a constant
                ret = ExprInt(name.offset, expr.size)
            elif expr in self.regstop:
                ret = exprid_top(expr)
            else:
                ret = state.get(expr, expr)
        elif expr.is_mem():
            ret = self.manage_mem(expr, state, cache, level)
        elif expr.is_cond():
            sons = (visit(expr.cond), visit(expr.src1), visit(expr.src2))
            if any(son.is_id(TOPSTR) for son in sons):
                ret = exprid_top(expr)
            else:
                ret = ExprCond(*sons)
        elif expr.is_slice():
            son = visit(expr.arg)
            ret = (exprid_top(expr) if son.is_id(TOPSTR)
                   else ExprSlice(son, expr.start, expr.stop))
        elif expr.is_op():
            new_args = []
            for son in expr.args:
                value = visit(son)
                assert son.size == value.size
                # A single TOP operand makes the whole operation TOP
                if value.is_id(TOPSTR):
                    return exprid_top(expr)
                new_args.append(value)
            ret = ExprOp(expr.op, *new_args)
        elif expr.is_compose():
            new_args = []
            for son in expr.args:
                value = visit(son)
                if value.is_id(TOPSTR):
                    return exprid_top(expr)
                new_args.append(value)
            ret = ExprCompose(*new_args)
        else:
            raise TypeError("Unknown expr type")

        ret = self.expr_simp(ret)
        assert expr.size == ret.size
        cache[expr] = ret
        return ret
Exemple #28
0
    def from_ExprOp(self, expr):
        """Translate the ExprOp @expr into a C expression string.

        Operands no wider than self.NATIVE_INT_MAX_SIZE bits are emitted as
        native C integer code masked with self._size2mask; wider operands are
        emitted as calls to the bignum_* helpers.

        @expr: ExprOp to translate
        Returns the generated C code as a str.
        Raises NotImplementedError for unknown operators, RuntimeError for
        unsupported floating point conversion sizes, and ValueError for fpu
        operations on operands wider than NATIVE_INT_MAX_SIZE.
        """
        if len(expr.args) == 1:
            if expr.op == 'parity':
                arg = expr.args[0]
                out = self.from_expr(arg)
                if arg.size <= self.NATIVE_INT_MAX_SIZE:
                    out = "(%s&%s)" % (out, self._size2mask(arg.size))
                else:
                    out = 'bignum_mask(%s, %d)' % (out, 8)
                    out = 'bignum_to_uint64(%s)' % out
                out = 'parity(%s)' % out
                return out

            elif expr.op.startswith("zeroExt_"):
                arg = expr.args[0]
                if expr.size == arg.size:
                    # Nothing to extend, but callers expect C code, not the
                    # Expr object itself.
                    return self.from_expr(arg)
                return self.from_expr(
                    ExprCompose(arg, ExprInt(0, expr.size - arg.size)))

            elif expr.op.startswith("signExt_"):
                arg = expr.args[0]
                if expr.size == arg.size:
                    # Nothing to extend: still return generated C code.
                    return self.from_expr(arg)
                add_size = expr.size - arg.size
                # Upper bits are all ones when the sign bit is set, else 0.
                new_expr = ExprCompose(
                    arg,
                    ExprCond(arg.msb(), ExprInt(size2mask(add_size), add_size),
                             ExprInt(0, add_size)))
                return self.from_expr(new_expr)

            elif expr.op in ['cntleadzeros', 'cnttrailzeros']:
                arg = expr.args[0]
                out = self.from_expr(arg)
                if arg.size <= self.NATIVE_INT_MAX_SIZE:
                    out = "%s(0x%x, %s)" % (expr.op, expr.args[0].size, out)
                else:
                    out = "bignum_%s(%s, %d)" % (expr.op, out, arg.size)
                return out

            elif expr.op == '!':
                arg = expr.args[0]
                out = self.from_expr(arg)
                if expr.size <= self.NATIVE_INT_MAX_SIZE:
                    # Wrap the whole result (as the other native branches do)
                    # so the mask cannot bind to a surrounding operator once
                    # this string is embedded in a larger C expression.
                    out = "((~ %s)&%s)" % (out, self._size2mask(arg.size))
                else:
                    out = "bignum_not(%s)" % out
                    out = "bignum_mask(%s, %d)" % (out, expr.size)
                return out

            elif expr.op in [
                    "ftan",
                    "frndint",
                    "f2xm1",
                    "fsin",
                    "fsqrt",
                    "fabs",
                    "fcos",
                    "fchs",
            ]:
                return "fpu_%s%d(%s)" % (
                    expr.op,
                    expr.size,
                    self.from_expr(expr.args[0]),
                )
            elif (expr.op.startswith("access_") or expr.op.startswith("load_")
                  or expr.op.startswith("fxam_c")):
                arg = expr.args[0]
                out = self.from_expr(arg)
                out = "%s(%s)" % (expr.op, out)
                return out

            elif expr.op == "-":
                arg = expr.args[0]
                out = self.from_expr(arg)
                if arg.size <= self.NATIVE_INT_MAX_SIZE:
                    out = "(%s(%s))" % (expr.op, out)
                    out = "(%s&%s)" % (out, self._size2mask(arg.size))
                else:
                    out = "bignum_sub(bignum_from_uint64(0), %s)" % out
                    out = "bignum_mask(%s, %d)" % (out, expr.size)
                return out

            elif expr.op.startswith("fpround_"):
                return "%s_fp%d(%s)" % (
                    expr.op,
                    expr.size,
                    self.from_expr(expr.args[0]),
                )
            elif expr.op == "sint_to_fp":
                size = expr.size
                arg = expr.args[0]
                if size not in [32, 64]:
                    raise RuntimeError("Unsupported size for sint_to_fp: %r" %
                                       size)
                return "%s_%d(%s)" % (expr.op, size, self.from_expr(arg))
            elif expr.op.startswith("fp_to_sint"):
                dest_size = expr.size
                arg_size = expr.args[0].size
                if (arg_size, dest_size) in [
                    (32, 32),
                    (64, 64),
                    (64, 32),
                ]:
                    func = "fp%d_to_sint%d" % (arg_size, dest_size)
                else:
                    raise RuntimeError(
                        "Unsupported size for fp_to_sint: %r to %r" %
                        (arg_size, dest_size))
                return "%s(%s)" % (func, self.from_expr(expr.args[0]))
            elif expr.op.startswith("fpconvert_fp"):
                dest_size = expr.size
                arg_size = expr.args[0].size
                if (arg_size, dest_size) in [(32, 64), (64, 32)]:
                    func = "fp%d_to_fp%d" % (arg_size, dest_size)
                else:
                    raise RuntimeError(
                        "Unsupported size for fpconvert: %r to %r" %
                        (arg_size, dest_size))
                return "%s(%s)" % (func, self.from_expr(expr.args[0]))
            else:
                raise NotImplementedError('Unknown op: %r' % expr.op)

        elif len(expr.args) == 2:
            if expr.op == TOK_EQUAL:
                return '(((%s&%s) == (%s&%s))?1:0)' % (
                    self.from_expr(expr.args[0]),
                    self._size2mask(expr.args[0].size),
                    self.from_expr(expr.args[1]),
                    self._size2mask(expr.args[1].size),
                )
            elif expr.op in self.dct_shift:
                arg0 = self.from_expr(expr.args[0])
                arg1 = self.from_expr(expr.args[1])
                if expr.size <= self.NATIVE_INT_MAX_SIZE:
                    out = 'SHIFT_%s(%d, %s, %s)' % (self.dct_shift[
                        expr.op].upper(), expr.args[0].size, arg0, arg1)
                else:
                    op = {"<<": "lshift", ">>": "rshift", "a>>": "a_rshift"}
                    out = "bignum_%s(%s, bignum_to_uint64(%s))" % (op[expr.op],
                                                                   arg0, arg1)
                    out = "bignum_mask(%s, %d)" % (out, expr.size)
                return out

            elif expr.is_associative():
                args = [self.from_expr(arg) for arg in expr.args]
                if expr.size <= self.NATIVE_INT_MAX_SIZE:
                    out = (" %s " % expr.op).join(args)
                    out = "((%s)&%s)" % (out, self._size2mask(expr.size))
                else:
                    op_to_bn_func = {
                        "+": "add",
                        "*": "mul",
                        "|": "or",
                        "^": "xor",
                        "&": "and",
                    }
                    args = list(expr.args)
                    out = self.from_expr(args.pop())
                    while args:
                        out = 'bignum_mask(bignum_%s(%s, %s), %d)' % (
                            op_to_bn_func[expr.op], out,
                            self.from_expr(args.pop()), expr.size)
                return out

            elif expr.op in ['-']:
                return '(((%s&%s) %s (%s&%s))&%s)' % (
                    self.from_expr(expr.args[0]),
                    self._size2mask(expr.args[0].size), str(
                        expr.op), self.from_expr(
                            expr.args[1]), self._size2mask(expr.args[1].size),
                    self._size2mask(expr.args[0].size))
            elif expr.op in self.dct_rot:
                arg0 = self.from_expr(expr.args[0])
                arg1 = self.from_expr(expr.args[1])
                if expr.size <= self.NATIVE_INT_MAX_SIZE:
                    out = '(%s(%s, %s, %s) &%s)' % (
                        self.dct_rot[expr.op],
                        expr.args[0].size,
                        arg0,
                        arg1,
                        self._size2mask(expr.args[0].size),
                    )
                else:
                    op = {">>>": "ror", "<<<": "rol"}
                    out = "bignum_%s(%s, %d, bignum_to_uint64(%s))" % (
                        op[expr.op], arg0, expr.size, arg1)
                    out = "bignum_mask(%s, %d)" % (out, expr.size)
                return out

            elif expr.op == 'x86_cpuid':
                return "%s(%s, %s)" % (expr.op, self.from_expr(
                    expr.args[0]), self.from_expr(expr.args[1]))
            elif expr.op.startswith("fcom"):
                arg0 = self.from_expr(expr.args[0])
                arg1 = self.from_expr(expr.args[1])
                if not expr.args[0].size <= self.NATIVE_INT_MAX_SIZE:
                    raise ValueError(
                        "Bad semantic: fpu do operations do not support such size"
                    )
                out = "fpu_%s(%s, %s)" % (expr.op, arg0, arg1)
                return out

            elif expr.op in [
                    "fadd", "fsub", "fdiv", 'fmul', "fscale", "fprem", "fyl2x",
                    "fpatan"
            ]:
                arg0 = self.from_expr(expr.args[0])
                arg1 = self.from_expr(expr.args[1])
                if not expr.args[0].size <= self.NATIVE_INT_MAX_SIZE:
                    raise ValueError(
                        "Bad semantic: fpu do operations do not support such size"
                    )
                out = "fpu_%s%d(%s, %s)" % (expr.op, expr.size, arg0, arg1)
                return out

            elif expr.op == "segm":
                return "segm2addr(jitcpu, %s, %s)" % (self.from_expr(
                    expr.args[0]), self.from_expr(expr.args[1]))

            elif expr.op in ['udiv', 'umod']:
                arg0 = self.from_expr(expr.args[0])
                arg1 = self.from_expr(expr.args[1])

                if expr.size <= self.NATIVE_INT_MAX_SIZE:
                    out = '%s%d(%s, %s)' % (expr.op, expr.args[0].size, arg0,
                                            arg1)
                else:
                    out = "bignum_%s(%s, %s)" % (expr.op, arg0, arg1)
                    out = "bignum_mask(%s, %d)" % (out, expr.size)
                return out

            elif expr.op in ['sdiv', 'smod']:
                arg0 = self.from_expr(expr.args[0])
                arg1 = self.from_expr(expr.args[1])

                if expr.size <= self.NATIVE_INT_MAX_SIZE:
                    out = '%s%d(%s, %s)' % (expr.op, expr.args[0].size, arg0,
                                            arg1)
                else:
                    out = "bignum_%s(%s, %s, %d)" % (expr.op, arg0, arg1,
                                                     expr.size)
                    out = "bignum_mask(%s, %d)" % (out, expr.size)
                return out

            elif expr.op in ["bcdadd", "bcdadd_cf"]:
                return "%s_%d(%s, %s)" % (expr.op, expr.args[0].size,
                                          self.from_expr(expr.args[0]),
                                          self.from_expr(expr.args[1]))
            else:
                raise NotImplementedError('Unknown op: %r' % expr.op)

        elif len(expr.args) >= 3 and expr.is_associative():
            # N-ary associative op: native code only, each operand masked to
            # its own size and the result masked to the first operand's size.
            oper = [
                '(%s&%s)' % (
                    self.from_expr(arg),
                    self._size2mask(arg.size),
                ) for arg in expr.args
            ]
            oper = str(expr.op).join(oper)
            return "((%s)&%s)" % (oper, self._size2mask(expr.args[0].size))
        else:
            raise NotImplementedError('Unknown op: %s' % expr.op)
Exemple #29
0
import pickle
from miasm2.expression.expression import ExprInt, ExprAff, ExprId, \
    Expr, ExprCompose, ExprMem


# Check that each kind of miasm2 expression (id, int, op, compose, mem,
# slice, assignment) survives a pickle round trip: every unpickled object
# must compare equal to its original.
a = ExprId("test", 8)
b = ExprInt(1338, 8)
c = a + b
d = ExprCompose(a, b)
e = ExprMem(a, 32)
f = a[:8]
aff = ExprAff(a, b)

originals = (a, b, c, d, e, f, aff)


def _show(values):
    """Print @values on a single line, space separated (Python 2 and 3)."""
    print(" ".join(str(value) for value in values))


print('Pickling')
out = pickle.dumps(originals)
print('Unpickling')
# pickle.loads is acceptable here only because the payload was produced
# just above; never unpickle untrusted data.
restored = pickle.loads(out)
new_a, new_b, new_c, new_d, new_e, new_f, new_aff = restored
print('Result')
_show(originals)
_show(id(value) for value in originals)
_show(restored)
_show(id(value) for value in restored)

for before, after in zip(originals, restored):
    assert before == after, (before, after)