예제 #1
0
def signed_lt(op1, op2, state):
    """Build a claripy "less than" predicate for two integer operands.

    The operands are first unified with ``int_unif``. If both unified types
    share the same signedness, a plain signed (``SLT``) or unsigned (``ULT``)
    comparison is returned. If their signedness differs, a negative value on the
    signed side decides the answer by itself; otherwise the two values are
    compared with claripy's (unsigned) ``<``.
    """
    lhs, rhs, lhs_type, rhs_type = int_unif(op1, op2, state)

    if lhs_type.signed != rhs_type.signed:
        signed_side = lhs if lhs_type.signed else rhs
        # A negative signed lhs is below any unsigned rhs (True); a negative
        # signed rhs is never above an unsigned lhs (False). Both cases
        # collapse to lhs_type.signed.
        return claripy.If(claripy.SLT(signed_side, 0), lhs_type.signed, lhs < rhs)

    if lhs_type.signed:
        return claripy.SLT(lhs, rhs)
    return claripy.ULT(lhs, rhs)
예제 #2
0
def signed_eq(op1, op2, state):
    """Build a claripy equality predicate for two integer operands.

    The operands are first unified with ``int_unif``. With matching signedness
    the unified values are compared directly. With mixed signedness, a negative
    value on the signed side can never equal the unsigned one, so the predicate
    is False in that case and a direct comparison otherwise.
    """
    lhs, rhs, lhs_type, rhs_type = int_unif(op1, op2, state)

    if lhs_type.signed != rhs_type.signed:
        signed_side = lhs if lhs_type.signed else rhs
        return claripy.If(claripy.SLT(signed_side, 0), False, lhs == rhs)

    return lhs == rhs
예제 #3
0
 def _op_generic_CmpORD(self, args):
     """Three-way compare ``args[0]`` with ``args[1]`` into a flag bitvector.

     The result is ``self._from_size`` bits wide and holds 0x2 when the operands
     are equal, 0x8 when the first is less than the second, and 0x4 otherwise.
     ``self.is_signed`` selects a signed (SLT) or unsigned (ULT) less-than test.
     """
     lhs, rhs = args[0], args[1]
     width = self._from_size
     less_than = claripy.SLT if self.is_signed else claripy.ULT
     is_less = less_than(lhs, rhs)
     lt_or_gt = claripy.If(is_less, claripy.BVV(0x8, width), claripy.BVV(0x4, width))
     return claripy.If(lhs == rhs, claripy.BVV(0x2, width), lt_or_gt)
예제 #4
0
 def _op_generic_StoU_saturation(self, value, min_value, max_value):  #pylint:disable=no-self-use
     """
     Return unsigned saturated BV from signed BV.

     ``value`` is compared with signed operators: anything above ``max_value``
     becomes ``max_value``, anything below ``min_value`` becomes ``min_value``,
     and everything else passes through unchanged.
     Min and max value should be unsigned.
     """
     clamped_low = claripy.If(claripy.SLT(value, min_value), min_value, value)
     return claripy.If(claripy.SGT(value, max_value), max_value, clamped_low)
예제 #5
0
 def mark_boundaries(self, var_1, var_2):
     """Keep var_2 at least as far above var_1 as in the concrete layout.

     Adds the hard signed constraint ``var_1.sym_addr + diff <= var_2.sym_addr``
     (``diff`` being the distance between the concrete addresses) to
     ``self.solver``. The strict form of the same constraint is appended to
     ``self.stack.unsafe_constraints``. Each pair, keyed by the two concrete
     addresses, is processed at most once.
     """
     pair = (var_1.conc_addr, var_2.conc_addr)
     if pair not in self.bounds_marked:
         self.bounds_marked.add(pair)
         shifted = var_1.sym_addr + (var_2.conc_addr - var_1.conc_addr)
         self.solver.add(claripy.SLE(shifted, var_2.sym_addr))
         self.stack.unsafe_constraints.append(
             claripy.SLT(shifted, var_2.sym_addr))
예제 #6
0
def compute_mem_accesses(project, g, init_state, final_state):
    """Classify the memory accesses performed while executing a gadget.

    Walks every memory read and write action recorded in ``final_state``'s
    history. For each access it collects what the accessed address depends on:

    * ``Arch.Registers[REG]`` for symbolic variables named ``sreg_REG-...``
      (``Arch.UnknownType.unknown`` when ``REG`` is not a known register);
    * ``Arch.MemType.stack`` for variables named ``symbolic_stack...``;
    * ``Arch.UnknownType.unknown`` for any other variable, and for concrete
      addresses that may fall outside the tolerated stack window (see below).

    :param project:     not used by this function
    :param g:           the gadget; only ``g.stack_fix`` is read
    :param init_state:  state at gadget entry; ``init_state.regs.sp`` is the reference stack pointer
    :param final_state: state after the gadget; its history and solver are queried
    :return:            ``(frozenset(mem), simple_accesses)``. ``simple_accesses`` is False if
                        any address is a compound symbolic expression or may touch memory
                        outside the tolerated stack window.
    """
    mem = set()
    # Is the memory access performed through a simple dereference? e.g.: mov n, [REG]
    simple_accesses = True
    word_size = Arch.ARCH_BITS // 8
    read_window = Arch.STACK_CELLS * word_size
    # TODO: ast must be created from a symbolic state where registers values are named "sreg_REG-"
    for a in chain(final_state.history.filter_actions(read_from=ANGR_MEM),
                   final_state.history.filter_actions(write_to=ANGR_MEM)):
        addr_ast = a.addr.ast
        # TODO: for now only check regs from which depends
        for var in addr_ast.variables:
            if var.startswith("sreg_"):
                # get the name of the register from symbolic name, previously initialized as sreg_REG-
                try:
                    mem.add(Arch.Registers[var[5:].split("-")[0]])
                except KeyError:
                    mem.add(Arch.UnknownType.unknown)
            elif var.startswith("symbolic_stack"):
                mem.add(Arch.MemType.stack)
            else:
                mem.add(Arch.UnknownType.unknown)
        if addr_ast.symbolic and addr_ast.depth > 1:
            simple_accesses = False
        if not addr_ast.concrete:
            continue

        sp_offset = addr_ast - init_state.regs.sp
        if a.action == ANGR_READ:
            # Silently allow reads on the stack in [init.sp - window, init.sp + window];
            # they probably won't be useful anyway.
            # Note: < and > are unsigned by default in claripy, so BOTH bounds must use
            # the signed operators. An unsigned ">" would also flag every read just below
            # init.sp, because a negative offset wraps around to a huge unsigned value.
            constraints = claripy.Or(
                claripy.SGT(sp_offset, read_window),
                claripy.SLT(sp_offset, -read_window))
        elif a.action == ANGR_WRITE:
            # Check whether the write may hit fixed memory outside the area reserved
            # for the gadget on the stack.
            constraints = claripy.Or(
                # outside the area, or on the return address
                claripy.SGE(sp_offset, g.stack_fix - word_size),
                # before init of the gadget (an unsigned "< 0" could never be satisfied)
                claripy.SLT(sp_offset, 0))
        else:
            continue

        if final_state.satisfiable(extra_constraints=[constraints]):
            mem.add(Arch.UnknownType.unknown)
            simple_accesses = False
    return (frozenset(mem), simple_accesses)
예제 #7
0
    def claripy_ast_from_ail_condition(self, condition) -> claripy.ast.Base:
        """Convert an AIL expression into an equivalent claripy AST.

        Some leaf expressions are replaced by fresh, explicitly-named symbolic
        variables: loads, dirty expressions, base-pointer offsets, ITEs, calls,
        registers, conversions and left-over temporaries.
        ``self._condition_mapping`` maps each such variable's name
        (``var.args[0]``) back to the AIL expression it stands for. Constants
        become bitvector values. Operations listed in ``_mapping`` are converted
        recursively, and the AIL expression's tags are attached to the result as
        a ``TagsAnnotation``.

        :param condition:   An AIL expression, or an already-built claripy AST
                            (which is returned unchanged).
        :return:            The claripy AST for ``condition``.
        :raises NotImplementedError: if ``condition.verbose_op`` has no mapping.
        """

        # Unpack a condition all the way to the leaves
        if isinstance(condition, claripy.ast.Base):  # pylint:disable=isinstance-second-argument-not-valid-type
            return condition

        def _op_with_unified_size(op, conv, operand0, operand1):
            # Apply a shift-like `op`, ensuring operand1 is of the same size as operand0
            if isinstance(operand1, ailment.Expr.Const):
                # a constant shift amount is passed as a plain Python int - the easy case
                return op(conv(operand0), operand1.value)
            if operand1.bits == operand0.bits:
                return op(conv(operand0), conv(operand1))
            # extension is required
            assert operand1.bits < operand0.bits
            operand1 = ailment.Expr.Convert(None, operand1.bits, operand0.bits, False, operand1)
            return op(conv(operand0), conv(operand1))

        _mapping = {
            'LogicalAnd': lambda expr, conv: claripy.And(conv(expr.operands[0]), conv(expr.operands[1])),
            'LogicalOr': lambda expr, conv: claripy.Or(conv(expr.operands[0]), conv(expr.operands[1])),
            'CmpEQ': lambda expr, conv: conv(expr.operands[0]) == conv(expr.operands[1]),
            'CmpNE': lambda expr, conv: conv(expr.operands[0]) != conv(expr.operands[1]),
            'CmpLE': lambda expr, conv: conv(expr.operands[0]) <= conv(expr.operands[1]),
            'CmpLEs': lambda expr, conv: claripy.SLE(conv(expr.operands[0]), conv(expr.operands[1])),
            'CmpLT': lambda expr, conv: conv(expr.operands[0]) < conv(expr.operands[1]),
            'CmpLTs': lambda expr, conv: claripy.SLT(conv(expr.operands[0]), conv(expr.operands[1])),
            'CmpGE': lambda expr, conv: conv(expr.operands[0]) >= conv(expr.operands[1]),
            'CmpGEs': lambda expr, conv: claripy.SGE(conv(expr.operands[0]), conv(expr.operands[1])),
            'CmpGT': lambda expr, conv: conv(expr.operands[0]) > conv(expr.operands[1]),
            'CmpGTs': lambda expr, conv: claripy.SGT(conv(expr.operands[0]), conv(expr.operands[1])),
            'Add': lambda expr, conv: conv(expr.operands[0]) + conv(expr.operands[1]),
            'Sub': lambda expr, conv: conv(expr.operands[0]) - conv(expr.operands[1]),
            'Mul': lambda expr, conv: conv(expr.operands[0]) * conv(expr.operands[1]),
            'Not': lambda expr, conv: claripy.Not(conv(expr.operand)),
            'Xor': lambda expr, conv: conv(expr.operands[0]) ^ conv(expr.operands[1]),
            'And': lambda expr, conv: conv(expr.operands[0]) & conv(expr.operands[1]),
            'Or': lambda expr, conv: conv(expr.operands[0]) | conv(expr.operands[1]),
            'Shr': lambda expr, conv: _op_with_unified_size(claripy.LShR, conv, expr.operands[0], expr.operands[1]),
            'Shl': lambda expr, conv: _op_with_unified_size(operator.lshift, conv, expr.operands[0], expr.operands[1]),
            'Sar': lambda expr, conv: _op_with_unified_size(operator.rshift, conv, expr.operands[0], expr.operands[1]),
        }

        if isinstance(condition, (ailment.Expr.Load, ailment.Expr.DirtyExpression, ailment.Expr.BasePointerOffset,
                                  ailment.Expr.ITE, ailment.Stmt.Call)):
            var = claripy.BVS('ailexpr_%s' % repr(condition), condition.bits, explicit_name=True)
            self._condition_mapping[var.args[0]] = condition
            return var
        elif isinstance(condition, ailment.Expr.Register):
            var = claripy.BVS('ailexpr_%s-%d' % (repr(condition), condition.idx), condition.bits, explicit_name=True)
            self._condition_mapping[var.args[0]] = condition
            return var
        elif isinstance(condition, ailment.Expr.Convert):
            # convert is special. if it generates a 1-bit variable, it should be treated as a BVS
            if condition.to_bits == 1:
                var_ = self.claripy_ast_from_ail_condition(condition.operands[0])
                name = 'ailcond_Conv(%d->%d, %s)' % (condition.from_bits, condition.to_bits, repr(var_))
                var = claripy.BoolS(name, explicit_name=True)
            else:
                var_ = self.claripy_ast_from_ail_condition(condition.operands[0])
                name = 'ailexpr_Conv(%d->%d, %s)' % (condition.from_bits, condition.to_bits, repr(var_))
                var = claripy.BVS(name, condition.to_bits, explicit_name=True)
            self._condition_mapping[var.args[0]] = condition
            return var
        elif isinstance(condition, ailment.Expr.Const):
            var = claripy.BVV(condition.value, condition.bits)
            return var
        elif isinstance(condition, ailment.Expr.Tmp):
            l.warning("Left-over ailment.Tmp variable %s.", condition)
            if condition.bits == 1:
                # A 1-bit temporary is a *symbolic* boolean named after the tmp, mirroring the
                # BVS branch below. (BoolV builds a concrete boolean and cannot take a name.)
                var = claripy.BoolS('ailtmp_%d' % condition.tmp_idx, explicit_name=True)
            else:
                var = claripy.BVS('ailtmp_%d' % condition.tmp_idx, condition.bits, explicit_name=True)
            self._condition_mapping[var.args[0]] = condition
            return var

        lambda_expr = _mapping.get(condition.verbose_op, None)
        if lambda_expr is None:
            # report the key that was actually looked up (e.g. "CmpLTs" rather than "CmpLT")
            raise NotImplementedError(
                "Unsupported AIL expression operation %s. Consider implementing." % condition.verbose_op)
        r = lambda_expr(condition, self.claripy_ast_from_ail_condition)
        if r is NotImplemented:
            r = claripy.BVS("ailexpr_%r" % condition, condition.bits, explicit_name=True)
            self._condition_mapping[r.args[0]] = condition
        else:
            # don't lose tags
            r = r.annotate(TagsAnnotation(**condition.tags))
        return r
예제 #8
0
    def interpret(self, startpos, args, addr=None, simfd=None):
        """
        implement scanf - extract formatted data from memory or a file according to the stored format
        specifiers and store them into the pointers extracted from `args`.

        Two strategies are used:

        * If `simfd` reads from a packet stream (`SimPackets`), every format component is read straight
          from the stream and the read bytes are constrained to be the textual form of the result
          ("parsing in reverse").
        * Otherwise data is taken from `simfd.read_storage` starting at the file position, or from
          `self.parser.state.memory` starting at `addr` when `simfd` is None. Afterwards, when `simfd`
          is given, the consumed bytes are read from it.

        :param startpos:    The index of the first argument corresponding to the first format element
        :param args:        A function which, given the index of an argument to the function, returns that argument
        :param addr:        The address in the memory to extract data from, or...
        :param simfd:       A file descriptor to use for reading data from
        :return:            The number of arguments parsed. On the packet-stream path this is a Python int. On
                            the other path it is `(argpos - startpos) - failed`, where `failed` is a bitvector
                            counting %d/%i/%u/%x conversions whose parse status was false.
        """
        if simfd is not None and isinstance(simfd.read_storage, SimPackets):
            # ---- packet-stream path: consume each format component directly from the stream ----
            argnum = startpos
            for component in self.components:
                if type(component) is bytes:
                    # literal format text: the stream must contain exactly these bytes
                    sdata, _ = simfd.read_data(len(component),
                                               short_reads=False)
                    self.state.solver.add(sdata == component)
                elif isinstance(component, claripy.Bits):
                    # symbolic literal format text; len(component) // 8 is used as the byte count
                    sdata, _ = simfd.read_data(len(component) // 8,
                                               short_reads=False)
                    self.state.solver.add(sdata == component)
                elif component.spec_type == b's':
                    # %s: read length_spec bytes (or libc.buf_symbolic_bytes when unspecified),
                    # forbid every SCANF_DELIMITERS byte in the data, then store it NUL-terminated
                    if component.length_spec is None:
                        sdata, slen = simfd.read_data(
                            self.state.libc.buf_symbolic_bytes)
                    else:
                        sdata, slen = simfd.read_data(component.length_spec)
                    for byte in sdata.chop(8):
                        self.state.solver.add(
                            claripy.And(*[
                                byte != char for char in self.SCANF_DELIMITERS
                            ]))
                    self.state.memory.store(args(argnum), sdata, size=slen)
                    self.state.memory.store(
                        args(argnum) + slen, claripy.BVV(0, 8))
                    argnum += 1
                elif component.spec_type == b'c':
                    # %c: a single byte, stored without a terminator
                    sdata, _ = simfd.read_data(1, short_reads=False)
                    self.state.memory.store(args(argnum), sdata)
                    argnum += 1
                else:
                    # numeric conversion: %x -> base 16, %o -> base 8, anything else -> base 10
                    bits = component.size * 8
                    if component.spec_type == b'x':
                        base = 16
                    elif component.spec_type == b'o':
                        base = 8
                    else:
                        base = 10

                    # here's the variable representing the result of the parsing
                    target_variable = self.state.solver.BVS(
                        'scanf_' + component.string.decode(),
                        bits,
                        key=('api', 'scanf', argnum - startpos,
                             component.string))
                    negative = claripy.SLT(target_variable, 0)

                    # how many digits does it take to represent this variable fully?
                    max_digits = int(math.ceil(math.log(2**bits, base)))

                    # how many digits does the format specify?
                    spec_digits = component.length_spec

                    # how many bits can we specify as input?
                    available_bits = float(
                        'inf'
                    ) if spec_digits is None else spec_digits * math.log(
                        base, 2)
                    not_enough_bits = available_bits < bits

                    # how many digits will we model this input as?
                    digits = max_digits if spec_digits is None else spec_digits

                    # constrain target variable range explicitly if it can't take on all possible values
                    # (the lower bound leaves room for a leading '-' that occupies one digit position)
                    if not_enough_bits:
                        self.state.solver.add(
                            self.state.solver.And(
                                self.state.solver.SLE(target_variable,
                                                      (base**digits) - 1),
                                self.state.solver.SGE(
                                    target_variable,
                                    -(base**(digits - 1) - 1))))

                    # perform the parsing in reverse - constrain the input digits to be the string version of the input
                    # this only works because we're reading from a packet stream and therefore nobody has the ability
                    # to add other constraints to this data!
                    # this makes z3's job EXTREMELY easy
                    sdata, _ = simfd.read_data(digits, short_reads=False)
                    # i counts character positions from the right, i.e. i is the power of `base`
                    for i, digit in enumerate(reversed(sdata.chop(8))):
                        digit_value = (target_variable // (base**i)) % base
                        digit_ascii = digit_value + ord('0')
                        if base > 10:
                            # digits 10..15 are rendered as lowercase 'a'..'f'
                            digit_ascii = claripy.If(
                                digit_value >= 10,
                                digit_value + (-10 + ord('a')), digit_ascii)

                        # if there aren't enough bits, we can increase the range by accounting for the possibility that
                        # the first digit is a minus sign
                        if not_enough_bits:
                            if i == digits - 1:
                                neg_digit_ascii = ord('-')
                            else:
                                neg_digit_value = (-target_variable //
                                                   (base**i)) % base
                                neg_digit_ascii = neg_digit_value + ord('0')
                                if base > 10:
                                    neg_digit_ascii = claripy.If(
                                        neg_digit_value >= 10,
                                        neg_digit_value + (-10 + ord('a')),
                                        neg_digit_ascii)

                            digit_ascii = claripy.If(negative, neg_digit_ascii,
                                                     digit_ascii)

                        self.state.solver.add(digit == digit_ascii[7:0])

                    self.state.memory.store(
                        args(argnum),
                        target_variable,
                        endness=self.state.arch.memory_endness)
                    argnum += 1

            return argnum - startpos

        # ---- region path: parse from the fd's storage or from memory ----

        # TODO: we only support one format specifier in interpretation for now

        format_specifier_count = sum(1 for x in self.components
                                     if isinstance(x, FormatSpecifier))
        if format_specifier_count > 1:
            l.warning(
                "We don't support more than one format specifiers in format strings."
            )

        # when reading from a file, parse from its storage starting at the current file position
        # (this overrides the `addr` argument)
        if simfd is not None:
            region = simfd.read_storage
            addr = simfd._pos if hasattr(
                simfd, '_pos') else simfd._read_pos  # XXX THIS IS BAD
        else:
            region = self.parser.state.memory

        bits = self.parser.state.arch.bits
        # number of numeric conversions that failed to parse, as an arch-width bitvector
        failed = self.parser.state.solver.BVV(0, bits)
        argpos = startpos
        position = addr
        for component in self.components:
            if isinstance(component, bytes):
                # TODO we skip non-format-specifiers in format string interpretation for now
                # if the region doesn't match the concrete component, we need to return immediately
                pass
            else:
                fmt_spec = component
                # NOTE(review): if fetching the argument fails, dest stays None and is still passed to
                # memory.store below - confirm whether that is intended
                try:
                    dest = args(argpos)
                except SimProcedureArgumentError:
                    dest = None
                if fmt_spec.spec_type == b's':
                    # set some limits for the find
                    max_str_len = self.parser.state.libc.max_str_len
                    max_sym_bytes = self.parser.state.libc.buf_symbolic_bytes

                    # has the length of the format been limited by the string itself?
                    if fmt_spec.length_spec is not None:
                        max_str_len = fmt_spec.length_spec
                        max_sym_bytes = fmt_spec.length_spec

                    # TODO: look for limits on other characters which scanf is sensitive to, '\x00', '\x20'
                    ohr, ohc, ohi = region.find(
                        position,
                        self.parser.state.solver.BVV(b'\n'),
                        max_str_len,
                        max_symbolic_bytes=max_sym_bytes)

                    # if no newline is found, mm is position + max_strlen
                    # If-branch will really only happen for format specifiers with a length
                    mm = self.parser.state.solver.If(ohr == 0,
                                                     position + max_str_len,
                                                     ohr)
                    # we're just going to concretize the length, load will do this anyways
                    length = self.parser.state.solver.max_int(mm - position)
                    src_str = region.load(position, length)

                    # TODO all of these should be delimiters we search for above
                    # add that the contents of the string cannot be any scanf %s string delimiters
                    for delimiter in set(FormatString.SCANF_DELIMITERS):
                        delim_bvv = self.parser.state.solver.BVV(delimiter)
                        for i in range(length):
                            self.parser.state.add_constraints(
                                region.load(position + i, 1) != delim_bvv)

                    # write it out to the pointer
                    self.parser.state.memory.store(dest, src_str)
                    # store the terminating null byte
                    self.parser.state.memory.store(
                        dest + length, self.parser.state.solver.BVV(0, 8))

                    position += length

                else:

                    # XXX: atoi only supports strings of one byte
                    if fmt_spec.spec_type in [b'd', b'i', b'u', b'x']:
                        base = 16 if fmt_spec.spec_type == b'x' else 10
                        status, i, num_bytes = self.parser._sim_atoi_inner(
                            position,
                            region,
                            base=base,
                            read_length=fmt_spec.length_spec)
                        # increase failed count if we were unable to parse it
                        failed = self.parser.state.solver.If(
                            status, failed, failed + 1)
                        position += num_bytes
                    elif fmt_spec.spec_type == b'c':
                        # %c: one byte, zero-extended to the architecture width
                        i = region.load(position, 1)
                        i = i.zero_extend(bits - 8)
                        position += 1
                    else:
                        raise SimProcedureError(
                            "unsupported format spec '%s' in interpret" %
                            fmt_spec.spec_type)

                    # keep only the low fmt_spec.size bytes and store them with the arch's endianness
                    i = self.parser.state.solver.Extract(
                        fmt_spec.size * 8 - 1, 0, i)
                    self.parser.state.memory.store(
                        dest,
                        i,
                        size=fmt_spec.size,
                        endness=self.parser.state.arch.memory_endness)

                argpos += 1

        # consume the parsed bytes from the file and require that all of them were actually read
        if simfd is not None:
            _, realsize = simfd.read_data(position - addr)
            self.state.solver.add(realsize == position - addr)

        return (argpos - startpos) - failed
예제 #9
0
    def constrain_variables(self, func, solver, stack):
        """Add the solver constraints describing the layout of ``stack``'s variables for ``func``.

        ``self.funcdata[func.addr]`` maps a concrete stack address to a tuple. Below, item 0 is
        read as the variable's size, item 1 as a fix marker ('TOP' / 'BOTTOM') and item 2 as its
        alignment. The steps are:

        1. sanity-check that every address in ``stack.addr_list`` lies inside a known variable
           ``[base, base + size)``;
        2. call ``stack.merge_up(i)`` for each address that does not start a known variable, and
           reject overlapping variables;
        3. constrain the allocation/deallocation ops and the symbolic stack size;
        4. size, align, link and order every variable.

        Side effects: sets ``self.offsets``, ``self.bounds_marked``, ``self.stack`` and
        ``self.solver``; adds constraints to ``solver``; appends to ``stack.unsafe_constraints``;
        mutates variables in ``stack.variables``; ``stack.merge_up`` may mutate ``stack``.

        :param func:    object whose ``addr`` keys ``self.funcdata``
        :param solver:  solver that receives the hard constraints (``solver.add``)
        :param stack:   stack model whose variables are constrained
        :raises Exception: when an access falls outside every known variable, or variables overlap
        """
        self.offsets = self.funcdata[func.addr]
        self.bounds_marked = set()
        self.stack = stack
        self.solver = solver

        # do some sanity checking first
        top = min(self.offsets)
        for addr in stack.addr_list:
            if addr < top:
                raise Exception("Provided vars miss an access (off the top!)")
            # walk down to the start of the known variable that should contain addr
            base_addr = addr
            while base_addr not in self.offsets:
                base_addr -= 1
            this_offset = addr - base_addr
            if this_offset >= self.offsets[base_addr][0]:
                raise Exception(
                    "Provided vars miss an access (between the cracks!)")

        # Index i only advances past addresses that start a known variable; for any other address
        # merge_up(i) is called and index i is examined again.
        # NOTE(review): this terminates only if merge_up changes stack.addr_list[i] - confirm.
        i = 0
        while i < len(stack.addr_list):
            addr = stack.addr_list[i]
            if addr in self.offsets:
                # the previous variable's end (start + size) must not pass this variable's start
                if i != 0 and self.offsets[stack.addr_list[
                        i - 1]][0] + stack.addr_list[i - 1] > addr:
                    raise Exception("Provided vars have an overlap!")
                i += 1
                continue
            stack.merge_up(i)

        # standard stuff
        # the allocation op's value is tied to -sym_size, every deallocation op's value to 0
        stack.alloc_op.apply_constraints(solver)
        solver.add(stack.alloc_op.symval == -stack.sym_size)
        for op in stack.dealloc_ops:
            op.apply_constraints(solver)
            solver.add(op.symval == 0)

        # hard: the symbolic size is a multiple of the word size and at least the concrete size.
        # unsafe (recorded only): strictly larger, at least 2x, and less than 3x the concrete size
        solver.add(stack.sym_size % stack.arch.bytes == 0)
        solver.add(claripy.SGE(stack.sym_size, stack.conc_size))
        stack.unsafe_constraints.append(
            claripy.SGT(stack.sym_size, stack.conc_size))
        stack.unsafe_constraints.append(
            claripy.SGE(stack.sym_size, stack.conc_size * 2))
        stack.unsafe_constraints.append(
            claripy.SLT(stack.sym_size, stack.conc_size * 3))

        # loop through variables, add the important constraints!
        i = 0
        while i < len(stack.addr_list):
            addr = stack.addr_list[i]
            var = stack.variables[addr]
            var.size = self.offsets[addr][0]
            fix = self.offsets[addr][1]
            if fix == 'TOP':
                var.special_top = True
            elif fix == 'BOTTOM':
                var.special_bottom = True

            align = self.offsets[addr][2]
            if align != 1:
                solver.add(var.sym_addr % align == 0)
            var.sym_link(
                solver,
                stack)  # this hooks up the constrains to actual immediates
            # also the top/bottom fixing happens in there

            # order this variable against both neighbours (mark_boundaries dedupes repeated pairs)
            if i != 0:
                prev_var = stack.variables[stack.addr_list[i - 1]]
                self.mark_boundaries(prev_var, var)
            if i != len(stack.addr_list) - 1:
                next_var = stack.variables[stack.addr_list[i + 1]]
                self.mark_boundaries(var, next_var)

                # ew. ew ew ew ew ew ew!!!
                # require var.sym_addr + diff to be signed-greater than var.sym_addr
                # (presumably to rule out wrap-around)
                diff = next_var.conc_addr - var.conc_addr
                solver.add(claripy.SLT(var.sym_addr, var.sym_addr + diff))
            if i == 0:
                # the first variable must not lie below the bottom of the allocated frame (-sym_size)
                solver.add(claripy.SLE(-stack.sym_size, var.sym_addr))

            i += 1
예제 #10
0
    def claripy_ast_from_ail_condition(self, condition):
        """Convert an AIL expression into an equivalent claripy AST.

        Loads, dirty expressions, base-pointer offsets, registers, conversions and
        left-over temporaries are replaced by fresh symbolic variables.
        ``self._condition_mapping`` maps each such variable AST back to the AIL
        expression it stands for. Constants become bitvector values. Operations
        listed in ``_mapping`` are converted recursively.

        :param condition:   An AIL expression, or an already-built claripy AST
                            (which is returned unchanged).
        :return:            The claripy AST for ``condition``.
        :raises NotImplementedError: if ``condition.verbose_op`` has no mapping.
        """

        # Unpack a condition all the way to the leaves
        if isinstance(condition, claripy.ast.Base):
            return condition

        def _lshr(expr, conv):
            # Logical right shift. A constant shift amount is passed as a plain int.
            # A non-constant amount is converted and zero-extended to the width of
            # the shifted value, since claripy combines equally sized bitvectors.
            value = conv(expr.operands[0])
            amount = expr.operands[1]
            if isinstance(amount, ailment.Expr.Const):
                return claripy.LShR(value, amount.value)
            amount = conv(amount)
            if amount.size() < value.size():
                amount = amount.zero_extend(value.size() - amount.size())
            return claripy.LShR(value, amount)

        _mapping = {
            'LogicalAnd':
            lambda expr, conv: claripy.And(conv(expr.operands[0]),
                                           conv(expr.operands[1])),
            'LogicalOr':
            lambda expr, conv: claripy.Or(conv(expr.operands[0]),
                                          conv(expr.operands[1])),
            'CmpEQ':
            lambda expr, conv: conv(expr.operands[0]) == conv(expr.operands[1]),
            'CmpNE':
            lambda expr, conv: conv(expr.operands[0]) != conv(expr.operands[1]),
            'CmpLE':
            lambda expr, conv: conv(expr.operands[0]) <= conv(expr.operands[1]),
            'CmpLEs':
            lambda expr, conv: claripy.SLE(conv(expr.operands[0]),
                                           conv(expr.operands[1])),
            'CmpLT':
            lambda expr, conv: conv(expr.operands[0]) < conv(expr.operands[1]),
            'CmpLTs':
            lambda expr, conv: claripy.SLT(conv(expr.operands[0]),
                                           conv(expr.operands[1])),
            'CmpGE':
            lambda expr, conv: conv(expr.operands[0]) >= conv(expr.operands[1]),
            'CmpGEs':
            lambda expr, conv: claripy.SGE(conv(expr.operands[0]),
                                           conv(expr.operands[1])),
            'CmpGT':
            lambda expr, conv: conv(expr.operands[0]) > conv(expr.operands[1]),
            'CmpGTs':
            lambda expr, conv: claripy.SGT(conv(expr.operands[0]),
                                           conv(expr.operands[1])),
            'Add':
            lambda expr, conv: conv(expr.operands[0]) + conv(expr.operands[1]),
            'Sub':
            lambda expr, conv: conv(expr.operands[0]) - conv(expr.operands[1]),
            'Not':
            lambda expr, conv: claripy.Not(conv(expr.operand)),
            'Xor':
            lambda expr, conv: conv(expr.operands[0]) ^ conv(expr.operands[1]),
            'And':
            lambda expr, conv: conv(expr.operands[0]) & conv(expr.operands[1]),
            'Or':
            lambda expr, conv: conv(expr.operands[0]) | conv(expr.operands[1]),
            'Shr':
            _lshr,
        }

        if isinstance(condition,
                      (ailment.Expr.Load, ailment.Expr.DirtyExpression,
                       ailment.Expr.BasePointerOffset)):
            var = claripy.BVS('ailexpr_%s' % repr(condition),
                              condition.bits,
                              explicit_name=True)
            self._condition_mapping[var] = condition
            return var
        elif isinstance(condition, ailment.Expr.Register):
            var = claripy.BVS('ailexpr_%s-%d' %
                              (repr(condition), condition.idx),
                              condition.bits,
                              explicit_name=True)
            self._condition_mapping[var] = condition
            return var
        elif isinstance(condition, ailment.Expr.Convert):
            # convert is special. if it generates a 1-bit variable, it should be treated as a BVS
            if condition.to_bits == 1:
                var_ = self.claripy_ast_from_ail_condition(
                    condition.operands[0])
                name = 'ailcond_Conv(%d->%d, %s)' % (
                    condition.from_bits, condition.to_bits, repr(var_))
                var = claripy.BoolS(name, explicit_name=True)
            else:
                var_ = self.claripy_ast_from_ail_condition(
                    condition.operands[0])
                name = 'ailexpr_Conv(%d->%d, %s)' % (
                    condition.from_bits, condition.to_bits, repr(var_))
                var = claripy.BVS(name, condition.to_bits, explicit_name=True)
            self._condition_mapping[var] = condition
            return var
        elif isinstance(condition, ailment.Expr.Const):
            var = claripy.BVV(condition.value, condition.bits)
            return var
        elif isinstance(condition, ailment.Expr.Tmp):
            l.warning("Left-over ailment.Tmp variable %s.", condition)
            if condition.bits == 1:
                # A 1-bit temporary is a *symbolic* boolean named after the tmp.
                # (BoolV builds a concrete boolean and cannot take a name.)
                var = claripy.BoolS('ailtmp_%d' % condition.tmp_idx,
                                    explicit_name=True)
            else:
                var = claripy.BVS('ailtmp_%d' % condition.tmp_idx,
                                  condition.bits)
            self._condition_mapping[var] = condition
            return var

        lambda_expr = _mapping.get(condition.verbose_op, None)
        if lambda_expr is None:
            # report the key that was actually looked up (e.g. "CmpLTs" rather than "CmpLT")
            raise NotImplementedError(
                "Unsupported AIL expression operation %s. Consider implementing."
                % condition.verbose_op)
        r = lambda_expr(condition, self.claripy_ast_from_ail_condition)
        if r is NotImplemented:
            r = claripy.BVS("ailexpr_%r" % condition,
                            condition.bits,
                            explicit_name=True)
            self._condition_mapping[r] = condition
        return r
예제 #11
0
    def exec_branch(self, state):  # pylint:disable=invalid-name
        """Execute forward from a state, queuing new states if needed.

        Instructions of ``self.code`` are interpreted one at a time, starting
        at ``state.pc``, until the state terminates or forks. ``state`` is
        mutated in place.

        Returns:
            True if execution ended normally (end of code, STOP, RETURN or
            SELFDESTRUCT). False if the state reverted, or if it was handed
            over to ``self.add_branch`` / ``self.add_for_fuzzing`` (jumps,
            storage accesses, calls, or values that must be fuzzed).

        Raises:
            MultipleSolutionsError: a value that must be concrete has several
                possible solutions and the opcode has no fuzzing fallback.
            utils.CodeError: executing PUSH data, INVALID, or a bad jump.
            utils.InterpreterError: unsupported opcode or operation.
        """
        logger.debug("Constraints: %s", state.solver.constraints)

        def solution(variable):
            """Returns the solution. There must be one or we fail."""
            # Requesting two solutions is enough to detect ambiguity.
            solutions = state.solver.eval(variable, 2)
            if len(solutions) > 1:
                raise MultipleSolutionsError(
                    "Multiple solutions for %s (%#x)" %
                    (variable, self.code[state.pc]))
            sol = solutions[0]
            return sol if isinstance(sol, numbers.Number) else sol.value

        self.code.pc = state.pc

        while True:
            if state.pc >= len(self.code):
                return True

            op = self.code[state.pc]
            self.code.pc += 1
            self.coverage[state.pc] += 1

            logger.debug("NEW STEP")
            logger.debug("Memory: %s", state.memory)
            logger.debug("Stack: %s", state.stack)
            logger.debug("PC: %i, %#x", state.pc, op)

            assert self.code.pc == state.pc + 1
            assert isinstance(op, numbers.Number)
            assert all(
                isinstance(i, claripy.ast.base.BV) for i in
                state.stack), "The stack musty only contains claripy BV's"

            # Trivial operations first
            if not self.code.is_valid_opcode(state.pc):
                raise utils.CodeError("Trying to execute PUSH data")
            elif op == 254:  # INVALID opcode
                raise utils.CodeError("designed INVALID opcode")
            elif op == opcode_values.JUMPDEST:
                pass
            elif op == opcode_values.ADD:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(s0 + s1)
            elif op == opcode_values.SUB:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(s0 - s1)
            elif op == opcode_values.MUL:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(s0 * s1)
            elif op == opcode_values.DIV:
                # We need to use claripy.LShR instead of a division if possible,
                # because the solver is bad dealing with divisions, better
                # with shifts. And we need shifts to handle the solidity ABI
                # for function selection.
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)  # pylint:disable=invalid-name
                except MultipleSolutionsError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0 / s1))
                else:
                    if s1 == 0:
                        state.stack_push(BVV_0)
                    elif s1 == 1:
                        state.stack_push(s0)
                    elif s1 & (s1 - 1) == 0:
                        # Power of two: bit_length() gives the exact exponent,
                        # whereas math.log is subject to float rounding.
                        exp = s1.bit_length() - 1
                        state.stack_push(s0.LShR(exp))
                    else:
                        state.stack_push(s0 / s1)
            elif op == opcode_values.SDIV:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)
                except MultipleSolutionsError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0.SDiv(s1)))
                else:
                    state.stack_push(BVV_0 if s1 == 0 else s0.SDiv(s1))
            elif op == opcode_values.MOD:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)
                except MultipleSolutionsError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0 % s1))
                else:
                    state.stack_push(BVV_0 if s1 == 0 else s0 % s1)
            elif op == opcode_values.SMOD:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)
                except MultipleSolutionsError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0.SMod(s1)))
                else:
                    state.stack_push(BVV_0 if s1 == 0 else s0.SMod(s1))
            elif op == opcode_values.ADDMOD:
                s0, s1, s2 = state.stack_pop(), state.stack_pop(
                ), state.stack_pop()
                # The EVM does not reduce the intermediate sum modulo 2**256,
                # so compute it on 257 bits and keep the low 256 bits.
                wide_sum = claripy.ZeroExt(1, s0) + claripy.ZeroExt(1, s1)
                try:
                    s2 = solution(s2)
                except MultipleSolutionsError:
                    state.stack_push(
                        claripy.If(
                            s2 == 0, BVV_0,
                            (wide_sum % claripy.ZeroExt(1, s2))[255:0]))
                else:
                    state.stack_push(
                        BVV_0 if s2 == 0 else (wide_sum % s2)[255:0])
            elif op == opcode_values.MULMOD:
                s0, s1, s2 = state.stack_pop(), state.stack_pop(
                ), state.stack_pop()
                # Same as ADDMOD: the product must not wrap at 2**256, and
                # the product of two 256-bit values fits in 512 bits.
                wide_product = (claripy.ZeroExt(256, s0) *
                                claripy.ZeroExt(256, s1))
                try:
                    s2 = solution(s2)
                except MultipleSolutionsError:
                    state.stack_push(
                        claripy.If(
                            s2 == 0, BVV_0,
                            (wide_product % claripy.ZeroExt(256, s2))[255:0]))
                else:
                    state.stack_push(
                        BVV_0 if s2 == 0 else (wide_product % s2)[255:0])
            elif op == opcode_values.SHL:
                shift, value = state.stack_pop(), state.stack_pop()
                state.stack_push(value << shift)
            elif op == opcode_values.SHR:
                shift, value = state.stack_pop(), state.stack_pop()
                state.stack_push(value.LShR(shift))
            elif op == opcode_values.SAR:
                shift, value = state.stack_pop(), state.stack_pop()
                # SAR is an arithmetic (sign-propagating) right shift; on
                # claripy bitvectors that is `>>` (LShR is the logical one).
                state.stack_push(value >> shift)
            elif op == opcode_values.EXP:
                base, exponent = state.stack_pop(), state.stack_pop()
                base_sol = solution(base)
                if base_sol == 2:
                    state.stack_push(1 << exponent)
                else:
                    try:
                        exponent_sol = solution(exponent)
                    except MultipleSolutionsError:
                        state.stack_push(exponent)  # restore stack
                        state.stack_push(base)
                        self.add_for_fuzzing(state, exponent,
                                             EXP_EXPONENT_FUZZ)
                        return False
                    else:
                        # Modular exponentiation gives the same 256-bit result
                        # without materializing an astronomically large int.
                        state.stack_push(
                            claripy.BVV(
                                pow(base_sol, exponent_sol, 1 << 256), 256))
            elif op == opcode_values.LT:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(bool_to_bv(claripy.ULT(s0, s1)))
            elif op == opcode_values.GT:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(bool_to_bv(claripy.UGT(s0, s1)))
            elif op == opcode_values.SLT:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(bool_to_bv(claripy.SLT(s0, s1)))
            elif op == opcode_values.SGT:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(bool_to_bv(claripy.SGT(s0, s1)))
            elif op == opcode_values.SIGNEXTEND:
                # TODO: Use Claripy's SignExt that should do exactly that.
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                # s0 is the index of the byte holding the sign bit (0 = least
                # significant byte). s1 is the number we want to extend.
                s0 = solution(s0)
                if s0 <= 31:
                    sign_bit = 1 << (s0 * 8 + 7)
                    state.stack_push(
                        claripy.If(
                            s1 & sign_bit == 0,
                            s1 & (sign_bit - 1),
                            s1 | ((1 << 256) - sign_bit),
                        ))
                else:
                    state.stack_push(s1)
            elif op == opcode_values.EQ:
                s0, s1 = state.stack_pop(), state.stack_pop()
                state.stack_push(bool_to_bv(s0 == s1))
            elif op == opcode_values.ISZERO:
                state.stack_push(bool_to_bv(state.stack_pop() == BVV_0))
            elif op == opcode_values.AND:
                s0, s1 = state.stack_pop(), state.stack_pop()
                state.stack_push(s0 & s1)
            elif op == opcode_values.OR:
                s0, s1 = state.stack_pop(), state.stack_pop()
                state.stack_push(s0 | s1)
            elif op == opcode_values.XOR:
                s0, s1 = state.stack_pop(), state.stack_pop()
                state.stack_push(s0 ^ s1)
            elif op == opcode_values.NOT:
                state.stack_push(~state.stack_pop())
            elif op == opcode_values.BYTE:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                # Shifting by 32 bytes (256 bits) yields 0 for out-of-range
                # byte indexes.
                state.stack_push(
                    s1.LShR(claripy.If(s0 > 31, 32, 31 - s0) * 8) & 0xFF)

            elif op == opcode_values.PC:
                state.stack_push(bvv(state.pc))
            elif op == opcode_values.GAS:
                state.stack_push(state.env.gas)
            elif op == opcode_values.ADDRESS:
                state.stack_push(state.env.address)
            elif op == opcode_values.BALANCE:
                addr = solution(state.stack_pop())
                if addr != solution(state.env.address):
                    raise utils.InterpreterError(
                        state,
                        "Can only query balance of the current contract for now"
                    )
                state.stack_push(state.env.balance)
            elif op == opcode_values.ORIGIN:
                state.stack_push(state.env.origin)
            elif op == opcode_values.CALLER:
                state.stack_push(state.env.caller)
            elif op == opcode_values.CALLVALUE:
                state.stack_push(state.env.value)
            elif op == opcode_values.BLOCKHASH:
                block_num = state.stack_pop()
                # Each distinct block number gets its own symbolic hash,
                # memoized in the environment so repeated queries agree.
                if block_num not in state.env.block_hashes:
                    state.env.block_hashes[block_num] = claripy.BVS(
                        "blockhash[%s]" % block_num, 256)
                state.stack_push(state.env.block_hashes[block_num])
            elif op == opcode_values.TIMESTAMP:
                state.stack_push(state.env.block_timestamp)
            elif op == opcode_values.NUMBER:
                state.stack_push(state.env.block_number)
            elif op == opcode_values.COINBASE:
                state.stack_push(state.env.coinbase)
            elif op == opcode_values.DIFFICULTY:
                state.stack_push(state.env.difficulty)
            elif op == opcode_values.POP:
                state.stack_pop()
            elif op == opcode_values.JUMP:
                addr = solution(state.stack_pop())
                if addr >= len(self.code
                               ) or self.code[addr] != opcode_values.JUMPDEST:
                    raise utils.CodeError("Invalid jump (%i)" % addr)
                state.pc = addr
                self.add_branch(state)
                return False
            elif op == opcode_values.JUMPI:
                addr, condition = solution(
                    state.stack_pop()), state.stack_pop()
                # Fork: the copy falls through, the original takes the jump.
                state_false = state.copy()
                state.solver.add(condition != BVV_0)
                state_false.solver.add(condition == BVV_0)
                state_false.pc += 1
                self.add_branch(state_false)
                state.pc = addr
                if (state.pc >= len(self.code)
                        or self.code[state.pc] != opcode_values.JUMPDEST):
                    # Report the invalid target itself, as JUMP does.
                    raise utils.CodeError("Invalid jump (%i)" % addr)
                self.add_branch(state)
                return False
            elif opcode_values.PUSH1 <= op <= opcode_values.PUSH32:
                pushnum = op - opcode_values.PUSH1 + 1
                # Reading the immediate also advances self.code.pc; keep
                # state.pc in sync with it.
                raw_value = self.code.read(pushnum)
                state.pc += pushnum
                state.stack_push(
                    bvv(int.from_bytes(raw_value, byteorder="big")))
            elif opcode_values.DUP1 <= op <= opcode_values.DUP16:
                depth = op - opcode_values.DUP1 + 1
                state.stack_push(state.stack[-depth])
            elif opcode_values.SWAP1 <= op <= opcode_values.SWAP16:
                depth = op - opcode_values.SWAP1 + 1
                state.stack[-depth - 1], state.stack[-1] = (
                    state.stack[-1], state.stack[-depth - 1])
            elif opcode_values.LOG0 <= op <= opcode_values.LOG4:
                # Logs are not modelled: only consume their operands.
                depth = op - opcode_values.LOG0
                mstart, msz = (state.stack_pop(), state.stack_pop())
                topics = [state.stack_pop() for x in range(depth)]
            elif op == opcode_values.SHA3:
                start, length = solution(state.stack_pop()), solution(
                    state.stack_pop())
                memory = state.memory.read(start, length)
                state.stack_push(Sha3(memory))
            elif op == opcode_values.STOP:
                return True
            elif op == opcode_values.RETURN:
                return True

            elif op == opcode_values.CALLDATALOAD:
                index = state.stack_pop()
                try:
                    index_sol = solution(index)
                except MultipleSolutionsError:
                    state.stack_push(index)  # restore the stack
                    self.add_for_fuzzing(state, index, CALLDATALOAD_INDEX_FUZZ)
                    return False
                state.stack_push(state.env.calldata.read(index_sol, 32))
            elif op == opcode_values.CALLDATASIZE:
                state.stack_push(state.env.calldata_size)
            elif op == opcode_values.CALLDATACOPY:
                # Snapshot before popping so fuzzing can replay this opcode.
                old_state = state.copy()
                mstart, dstart, size = (
                    state.stack_pop(),
                    state.stack_pop(),
                    state.stack_pop(),
                )
                mstart, dstart = solution(mstart), solution(dstart)
                try:
                    size = solution(size)
                except MultipleSolutionsError:
                    self.add_for_fuzzing(old_state, size,
                                         CALLDATACOPY_SIZE_FUZZ)
                    return False
                state.memory.copy_from(state.env.calldata, mstart, dstart,
                                       size)
            elif op == opcode_values.CODESIZE:
                state.stack_push(bvv(len(self.code)))
            elif op == opcode_values.EXTCODESIZE:
                addr = state.stack_pop()
                if (addr == state.env.address).is_true():
                    state.stack_push(bvv(len(self.code)))
                else:
                    # TODO: Improve that... It's clearly not constraining enough.
                    state.stack_push(claripy.BVS("EXTCODESIZE[%s]" % addr,
                                                 256))

            elif op == opcode_values.EXTCODECOPY:
                old_state = state.copy()
                addr = state.stack_pop()
                mem_start = solution(state.stack_pop())
                code_start = solution(state.stack_pop())

                size = state.stack_pop()
                try:
                    size = solution(size)
                except MultipleSolutionsError:
                    # TODO: Fuzz.
                    # self.add_for_fuzzing(old_state, size, [])
                    # return False
                    raise
                # External code is unknown: model it as fresh symbolic bytes.
                state.memory.write(
                    mem_start,
                    size,
                    claripy.BVS("EXTCODE[%s from %s]" % (addr, code_start),
                                size * 8),
                )

            elif op == opcode_values.CODECOPY:
                mem_start, code_start, size = [
                    solution(state.stack_pop()) for _ in range(3)
                ]
                # Bytes past the end of the code read as zero.
                for i in range(size):
                    if code_start + i < len(state.env.code):
                        state.memory.write(
                            mem_start + i,
                            1,
                            claripy.BVV(state.env.code[code_start + i], 8),
                        )
                    else:
                        state.memory.write(mem_start + i, 1, claripy.BVV(0, 8))

            elif op == opcode_values.MLOAD:
                index = solution(state.stack_pop())
                state.stack_push(state.memory.read(index, 32))
            elif op == opcode_values.MSTORE:
                index, value = solution(state.stack_pop()), state.stack_pop()
                state.memory.write(index, 32, value)
            elif op == opcode_values.MSTORE8:
                index, value = solution(state.stack_pop()), state.stack_pop()
                state.memory.write(index, 1, value[7:0])
            elif op == opcode_values.MSIZE:
                state.stack_push(bvv(state.memory.size()))

            elif op == opcode_values.SLOAD:
                state.pc += 1
                key = state.stack_pop()
                # Fork one state per previously written key that may alias
                # the read key; the remaining state assumes no aliasing.
                for w_key, w_value in state.storage_written.items():
                    read_written = [w_key == key]
                    if state.solver.satisfiable(
                            extra_constraints=read_written):
                        new_state = state.copy()
                        new_state.solver.add(read_written)
                        new_state.stack_push(w_value)
                        self.add_branch(new_state)
                    state.solver.add(w_key != key)
                if state.solver.satisfiable():
                    assert key not in state.storage_written
                    if key not in state.storage_read:
                        state.storage_read[key] = claripy.BVS(
                            "storage[%s]" % key, 256)
                    state.stack_push(state.storage_read[key])
                    self.add_branch(state)
                return False

            elif op == opcode_values.SSTORE:
                state.pc += 1
                key = state.stack_pop()
                value = state.stack_pop()
                # Fork one state per previously written key that may alias
                # the written key; the remaining state assumes no aliasing.
                for w_key, w_value in state.storage_written.items():
                    read_written = [w_key == key]
                    if state.solver.satisfiable(
                            extra_constraints=read_written):
                        new_state = state.copy()
                        new_state.solver.add(read_written)
                        new_state.storage_written[w_key] = value
                        self.add_branch(new_state)
                        state.solver.add(w_key != key)
                if state.solver.satisfiable():
                    assert key not in state.storage_written
                    state.storage_written[key] = value
                    self.add_branch(state)
                return False

            elif op == opcode_values.CALL:
                state.pc += 1

                # pylint:disable=unused-variable
                gas, to_, value, meminstart, meminsz, memoutstart, memoutsz = (
                    state.stack_pop() for _ in range(7))

                # First possibility: the call fails
                # (always possible with a call stack big enough)
                state_fail = state.copy()
                state_fail.stack_push(BVV_0)
                self.add_branch(state_fail)

                # Second possibility: success.
                state.calls.append((memoutsz, memoutstart, meminsz, meminstart,
                                    value, to_, gas))

                memoutsz = solution(memoutsz)
                if memoutsz != 0:
                    # If we expect some output, let's constraint the call to
                    # be to a contract that we do control. Otherwise it could
                    # return anything...
                    state.solver.add(to_[159:0] == utils.DEFAULT_CALLER[159:0])

                    memoutstart = solution(memoutstart)
                    state.memory.write(
                        memoutstart,
                        memoutsz,
                        claripy.BVS("CALL_RETURN[%s]" % to_, memoutsz * 8),
                    )

                state.stack_push(BVV_1)
                self.add_branch(state)
                return False

            elif op == opcode_values.DELEGATECALL:
                state.pc += 1

                # pylint:disable=unused-variable
                gas, to_, meminstart, meminsz, memoutstart, memoutsz = (
                    state.stack_pop() for _ in range(6))

                # First possibility: the call fails
                # (always possible with a call stack big enough)
                state_fail = state.copy()
                state_fail.stack_push(BVV_0)
                self.add_branch(state_fail)

                # If the call is to a specific contract we don't control,
                # don't assume it could return anything, or even be successful.
                # So we say we need to be able to call an arbitrary contract.
                state.solver.add(to_[159:0] == utils.DEFAULT_CALLER[159:0])

                # Second possibility: success.
                state.calls.append(
                    (memoutsz, memoutstart, meminsz, meminstart, to_, gas))

                memoutsz = solution(memoutsz)
                if memoutsz != 0:
                    memoutstart = solution(memoutstart)
                    state.memory.write(
                        memoutstart,
                        memoutsz,
                        claripy.BVS("DELEGATECALL_RETURN[%s]" % to_,
                                    memoutsz * 8),
                    )

                state.stack_push(BVV_1)
                self.add_branch(state)
                return False

            elif op == opcode_values.RETURNDATASIZE:
                state.stack_push(claripy.BVS("RETURNDATASIZE", 256))

            elif op == opcode_values.RETURNDATACOPY:
                # Snapshot before popping so fuzzing can replay this opcode.
                old_state = state.copy()
                mem_start_position = solution(state.stack_pop())
                returndata_start_position = solution(state.stack_pop())

                size = state.stack_pop()
                try:
                    size = solution(size)
                except MultipleSolutionsError:
                    self.add_for_fuzzing(old_state, size,
                                         RETURNDATACOPY_SIZE_FUZZ)
                    return False

                state.memory.write(mem_start_position, size,
                                   claripy.BVS("RETURNDATACOPY", size * 8))

            elif op == opcode_values.SELFDESTRUCT:
                # Terminal opcode: record the beneficiary (top of stack).
                state.selfdestruct_to = state.stack[-1]
                return True

            elif op == opcode_values.REVERT:
                return False

            else:
                raise utils.InterpreterError(state, "Unknown opcode %#x" % op)

            state.pc += 1
예제 #12
0
    def exec_branch(self, state):  # pylint:disable=invalid-name
        """Execute forward from a state, queuing new states if needed."""
        logger.debug("Constraints: %s", state.solver.constraints)

        def solution(variable):
            """Returns the solution. There must be one or we fail."""
            solutions = state.solver.eval(variable, 2)
            if len(solutions) > 1:
                raise ValueError("Ambiguous solution for %s (%s)" %
                                 (variable, self.code[state.pc]))
            solution = solutions[0]
            return solution if isinstance(solution,
                                          numbers.Number) else solution.value

        state.score += 1
        self.code.pc = state.pc

        while True:
            if state.pc >= len(self.code):
                return True

            op = self.code.next()
            self.coverage[state.pc] += 1

            logger.debug("NEW STEP")
            logger.debug("Memory: %s", state.memory)
            logger.debug("Stack: %s", state.stack)
            logger.debug("PC: %i, %s", state.pc, op)

            assert self.code.pc == state.pc + 1
            assert isinstance(op, numbers.Number)
            assert all(
                hasattr(i, "symbolic") for i in
                state.stack), "The stack musty only contains claripy BV's"

            # Trivial operations first
            if not self.code.is_valid_opcode(state.pc):
                raise utils.CodeError("Trying to execute PUSH data")
            elif op == 254:  # INVALID opcode
                raise utils.CodeError("designed INVALID opcode")
            elif op == opcode_values.JUMPDEST:
                pass
            elif op == opcode_values.ADD:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(s0 + s1)
            elif op == opcode_values.SUB:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(s0 - s1)
            elif op == opcode_values.MUL:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(s0 * s1)
            elif op == opcode_values.DIV:
                # We need to use claripy.LShR instead of a division if possible,
                # because the solver is bad dealing with divisions, better
                # with shifts. And we need shifts to handle the solidity ABI
                # for function selection.
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)  # pylint:disable=invalid-name
                except ValueError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0 / s1))
                else:
                    if s1 == 0:
                        state.stack_push(BVV_0)
                    elif s1 == 1:
                        state.stack_push(s0)
                    elif s1 & (s1 - 1) == 0:
                        exp = int(math.log(s1, 2))
                        state.stack_push(s0.LShR(exp))
                    else:
                        state.stack_push(s0 / s1)
            elif op == opcode_values.SDIV:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)
                except ValueError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0.SDiv(s1)))
                else:
                    state.stack_push(BVV_0 if s1 == 0 else s0.SDiv(s1))
            elif op == opcode_values.MOD:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)
                except ValueError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0 % s1))
                else:
                    state.stack_push(BVV_0 if s1 == 0 else s0 % s1)
            elif op == opcode_values.SMOD:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)
                except ValueError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0.SMod(s1)))
                else:
                    state.stack_push(BVV_0 if s1 == 0 else s0.SMod(s1))
            elif op == opcode_values.ADDMOD:
                s0, s1, s2 = state.stack_pop(), state.stack_pop(
                ), state.stack_pop()
                try:
                    s2 = solution(s2)
                except ValueError:
                    state.stack_push(claripy.If(s2 == 0, BVV_0,
                                                (s0 + s1) % s2))
                else:
                    state.stack_push(BVV_0 if s2 == 0 else (s0 + s1) % s2)
            elif op == opcode_values.MULMOD:
                s0, s1, s2 = state.stack_pop(), state.stack_pop(
                ), state.stack_pop()
                try:
                    s2 = solution(s2)
                except ValueError:
                    state.stack_push(claripy.If(s2 == 0, BVV_0,
                                                (s0 * s1) % s2))
                else:
                    state.stack_push(BVV_0 if s2 == 0 else (s0 * s1) % s2)
            elif op == opcode_values.EXP:
                base, exponent = solution(state.stack_pop()), state.stack_pop()
                if base == 2:
                    state.stack_push(1 << exponent)
                else:
                    exponent = solution(exponent)
                    state.stack_push(claripy.BVV(base**exponent, 256))
            elif op == opcode_values.LT:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(claripy.ULT(s0, s1))
            elif op == opcode_values.GT:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(claripy.UGT(s0, s1))
            elif op == opcode_values.SLT:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(claripy.SLT(s0, s1))
            elif op == opcode_values.SGT:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(claripy.SGT(s0, s1))
            elif op == opcode_values.SIGNEXTEND:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                # s0 is the number of bits. s1 the number we want to extend.
                s0 = solution(s0)
                if s0 <= 31:
                    sign_bit = 1 << (s0 * 8 + 7)
                    state.stack_push(
                        claripy.If(
                            s1 & sign_bit == 0,
                            s1 & (sign_bit - 1),
                            s1 | ((1 << 256) - sign_bit),
                        ))
                else:
                    state.stack_push(s1)
            elif op == opcode_values.EQ:
                s0, s1 = state.stack_pop(), state.stack_pop()
                if isinstance(s0, claripy.ast.Bool) and isinstance(
                        s1, claripy.ast.Bool):
                    state.stack_push(s0 == s1)
                else:
                    state.stack_push(not_bool(s0) == not_bool(s1))
            elif op == opcode_values.ISZERO:
                condition = state.stack_pop()
                if isinstance(condition, claripy.ast.Bool):
                    state.stack_push(claripy.Not(condition))
                else:
                    state.stack_push(condition == BVV_0)
            elif op == opcode_values.AND:
                s0, s1 = make_consistent(state.stack_pop(), state.stack_pop())
                if isinstance(s0, claripy.ast.Bool) and isinstance(
                        s1, claripy.ast.Bool):
                    state.stack_push(s0 and s1)
                else:
                    state.stack_push(s0 & s1)
            elif op == opcode_values.OR:
                s0, s1 = make_consistent(state.stack_pop(), state.stack_pop())
                if isinstance(s0, claripy.ast.Bool) and isinstance(
                        s1, claripy.ast.Bool):
                    state.stack_push(s0 or s1)
                else:
                    state.stack_push(s0 | s1)
            elif op == opcode_values.XOR:
                s0, s1 = make_consistent(state.stack_pop(), state.stack_pop())
                state.stack_push(s0 ^ s1)
            elif op == opcode_values.NOT:
                state.stack_push(~state.stack_pop())
            elif op == opcode_values.BYTE:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(
                    s1.LShR(claripy.If(s0 > 31, 32, 31 - s0) * 8) & 0xFF)

            elif op == opcode_values.PC:
                state.stack_push(bvv(state.pc))
            elif op == opcode_values.GAS:
                state.stack_push(state.env.gas)
            elif op == opcode_values.ADDRESS:
                state.stack_push(state.env.address)
            elif op == opcode_values.BALANCE:
                addr = solution(state.stack_pop())
                if addr != solution(state.env.address):
                    raise utils.InterpreterError(
                        state,
                        "Can only query balance of the current contract for now"
                    )
                state.stack_push(state.env.balance)
            elif op == opcode_values.ORIGIN:
                state.stack_push(state.env.origin)
            elif op == opcode_values.CALLER:
                state.stack_push(state.env.caller)
            elif op == opcode_values.CALLVALUE:
                state.stack_push(state.env.value)
            elif op == opcode_values.BLOCKHASH:
                block_num = state.stack_pop()
                if block_num not in state.env.block_hashes:
                    state.env.block_hashes[block_num] = claripy.BVS(
                        "blockhash[%s]" % block_num, 256)
                state.stack_push(state.env.block_hashes[block_num])
            elif op == opcode_values.TIMESTAMP:
                state.stack_push(state.env.block_timestamp)
            elif op == opcode_values.NUMBER:
                state.stack_push(state.env.block_number)
            elif op == opcode_values.COINBASE:
                state.stack_push(state.env.coinbase)
            elif op == opcode_values.DIFFICULTY:
                state.stack_push(state.env.difficulty)
            elif op == opcode_values.POP:
                state.stack_pop()
            elif op == opcode_values.JUMP:
                addr = solution(state.stack_pop())
                if addr >= len(self.code
                               ) or self.code[addr] != opcode_values.JUMPDEST:
                    raise utils.CodeError("Invalid jump (%i)" % addr)
                state.pc = addr
                self.add_branch(state)
                return False
            elif op == opcode_values.JUMPI:
                addr, condition = solution(
                    state.stack_pop()), state.stack_pop()
                state_false = state.copy()
                if isinstance(condition, claripy.ast.Bool):
                    state.solver.add(condition)
                    state_false.solver.add(claripy.Not(condition))
                else:
                    state.solver.add(condition != 0)
                    state_false.solver.add(condition == 0)
                state_false.pc += 1
                self.add_branch(state_false)
                state.pc = addr
                if (state.pc >= len(self.code)
                        or self.code[state.pc] != opcode_values.JUMPDEST):
                    raise utils.CodeError("Invalid jump (%i)" % (state.pc - 1))
                self.add_branch(state)
                return False
            elif opcode_values.PUSH1 <= op <= opcode_values.PUSH32:
                pushnum = op - opcode_values.PUSH1 + 1
                raw_value = self.code.read(pushnum)
                state.pc += pushnum
                state.stack_push(
                    bvv(int.from_bytes(raw_value, byteorder="big")))
            elif opcode_values.DUP1 <= op <= opcode_values.DUP16:
                depth = op - opcode_values.DUP1 + 1
                state.stack_push(state.stack[-depth])
            elif opcode_values.SWAP1 <= op <= opcode_values.SWAP16:
                depth = op - opcode_values.SWAP1 + 1
                temp = state.stack[-depth - 1]
                state.stack[-depth - 1] = state.stack[-1]
                state.stack[-1] = temp
            elif opcode_values.LOG0 <= op <= opcode_values.LOG4:
                depth = op - opcode_values.LOG0
                mstart, msz = (state.stack_pop(), state.stack_pop())
                topics = [state.stack_pop() for x in range(depth)]
            elif op == opcode_values.SHA3:
                start, length = solution(state.stack_pop()), solution(
                    state.stack_pop())
                memory = state.memory.read(start, length)
                state.stack_push(Sha3(memory))
            elif op == opcode_values.STOP:
                return True
            elif op == opcode_values.RETURN:
                return True

            elif op == opcode_values.CALLDATALOAD:
                indexes = state.stack_pop()
                try:
                    index = solution(indexes)
                except ValueError:  # Multiple solutions, let's fuzz.
                    state.stack_push(indexes)  # restore the stack
                    self.add_for_fuzzing(state, indexes, CALLDATASIZE_FUZZ)
                    return False
                state.solver.add(state.env.calldata_size >= index + 32)
                state.stack_push(state.env.calldata.read(index, 32))
            elif op == opcode_values.CALLDATASIZE:
                state.stack_push(state.env.calldata_size)
            elif op == opcode_values.CALLDATACOPY:
                old_state = state.copy()
                mstart, dstart, size = (
                    state.stack_pop(),
                    state.stack_pop(),
                    state.stack_pop(),
                )
                mstart, dstart = solution(mstart), solution(dstart)
                try:
                    size = solution(size)
                except ValueError:
                    self.add_for_fuzzing(old_state, size, CALLDATASIZE_FUZZ)
                    return False
                state.memory.copy_from(state.env.calldata, mstart, dstart,
                                       size)
                state.solver.add(state.env.calldata_size >= dstart + size)
            elif op == opcode_values.CODESIZE:
                state.stack_push(bvv(len(self.code)))
            elif op == opcode_values.EXTCODESIZE:
                addr = state.stack_pop()
                if (addr == state.env.address).is_true():
                    state.stack_push(bvv(len(self.code)))
                else:
                    # TODO: Improve that... It's clearly not constraining enough.
                    state.stack_push(claripy.BVS("EXTCODESIZE[%s]" % addr,
                                                 256))
            elif op == opcode_values.CODECOPY:
                mem_start, code_start, size = [
                    solution(state.stack_pop()) for _ in range(3)
                ]
                for i in range(size):
                    if code_start + i < len(state.env.code):
                        state.memory.write(
                            mem_start + i,
                            1,
                            claripy.BVV(state.env.code[code_start + i], 8),
                        )
                    else:
                        state.memory.write(mem_start + i, 1, claripy.BVV(0, 8))

            elif op == opcode_values.MLOAD:
                index = solution(state.stack_pop())
                state.stack_push(state.memory.read(index, 32))
            elif op == opcode_values.MSTORE:
                index, value = solution(state.stack_pop()), not_bool(
                    state.stack_pop())
                state.memory.write(index, 32, value)
            elif op == opcode_values.MSTORE8:
                index, value = solution(state.stack_pop()), not_bool(
                    state.stack_pop())
                state.memory.write(index, 1, value[7:0])
            elif op == opcode_values.MSIZE:
                state.stack_push(bvv(state.memory.size()))
            elif op == opcode_values.SLOAD:
                # TODO: This is inaccurate, because the storage can change
                # in a single transaction.
                # See commit d98cab834f8f359f01ef805256d179f5529ebe30.
                key = state.stack_pop()
                if key in state.storage_written:
                    state.stack_push(state.storage_written[key])
                else:
                    if key not in state.storage_read:
                        state.storage_read[key] = claripy.BVS(
                            "storage[%s]" % key, 256)
                    state.stack_push(state.storage_read[key])
            elif op == opcode_values.SSTORE:
                # TODO: This is inaccurate, because the storage can change
                # in a single transaction.
                # See commit d98cab834f8f359f01ef805256d179f5529ebe30.
                key = state.stack_pop()
                value = state.stack_pop()
                state.storage_written[key] = value

            elif op == opcode_values.CALL:
                state.pc += 1

                # First possibility: the call fails
                # (always possible with a call stack big enough)
                state_fail = state.copy()
                state_fail.stack_push(claripy.BoolV(False))
                self.add_branch(state_fail)

                # Second possibility: success.
                state.calls.append(state.stack[-7:])

                # pylint:disable=unused-variable
                gas, to_, value, meminstart, meminsz, memoutstart, memoutsz = (
                    state.stack_pop() for _ in range(7))

                if solution(memoutsz) != 0:
                    raise utils.InterpreterError(state,
                                                 "CALL seems to return data")
                if solution(meminsz) != 0:
                    raise utils.InterpreterError(state,
                                                 "CALL seems to take data")

                state.stack_push(claripy.BoolV(True))
                self.add_branch(state)
                return False

            elif op == opcode_values.SELFDESTRUCT:
                state.selfdestruct_to = state.stack[-1]
                return True

            elif op == opcode_values.REVERT:
                return False
            else:
                raise utils.InterpreterError(state, "Unknown opcode %s" % op)

            state.pc += 1
예제 #13
0
import settings
import claripy
# Analysis configuration for the RISC-V Poly1305 one-time-auth program named
# in TARGET_BINARY. Every value is stored as an attribute on the shared
# `settings` module. These statements run before `import tool` at the end of
# this script, which presumably reads them at import time (TODO confirm in
# tool.py), so keep their order.

# NOTE(review): the meaning of this address is defined by the consuming tool
# (presumably the address at which a warning is reported) -- confirm.
settings.WARNING_ADDRESS = 0x1fcb0
# Diagnostic output switches, interpreted by the tool.
settings.VERBOSE = True
settings.DEBUG = True
# Machine-specific absolute path to the ELF under analysis; adjust it for the
# local checkout before running.
settings.TARGET_BINARY = "/home/stefan/Documents/Graduation/RISC-V-toolchain/riscv/Programs/Poly1305_onetimeauth/Radix2.26_woMultiplier/program.elf"
# Presumably the address of the code to analyse inside TARGET_BINARY -- TODO
# confirm against the tool.
settings.TARGET_ADDRESS = 0x20011dfc
# One 32-bit symbolic bit-vector named 'A'.
settings.A = claripy.BVS('A', 32)
# A is the only entry in the symbolic parameter list...
settings.params = [settings.A]
# ...and is also the value designated as the secret.
settings.secret = settings.A
# Restrict A to the signed range 1 <= A < 5 (SGE/SLT are signed comparisons).
settings.constraints = [claripy.SLT(settings.A, 5), claripy.SGE(settings.A, 1)]
# Timing strategy constant from the project's pluginTime module.
from pluginTime import TIME_STRATEGY_SHORTEST
settings.TIME_STRATEGY = TIME_STRATEGY_SHORTEST


def stateInit(startState):
    """Per-target hook for preparing the initial symbolic state.

    This target needs no extra setup: ``startState`` is ignored and the hook
    unconditionally returns True. What True means is decided by the caller
    (presumably "proceed with the analysis" -- confirm in the tool).
    """
    del startState  # deliberately unused: nothing to initialise here
    return True


# Register the initial-state hook on the shared settings module.
settings.stateInit = stateInit
# Imported last on purpose: the registration above must already be in place.
# Presumably importing `tool` reads `settings` and starts the analysis at
# import time -- TODO confirm in tool.py.
import tool
예제 #14
0
 def test_send_back_negative_signed(self):
     """Signed-negative calldata passed to a call: check_state must be falsy.

     The first 32-byte calldata word is handed to get_call and the result is
     recorded on the state's call list. The solver is then constrained so that
     word is negative under a signed comparison (claripy.SLT).
     """
     call_argument = self.env.calldata.read(0, 32)
     self.state.calls.append(self.get_call(call_argument))
     # The word is read a second time for the constraint, so each use gets its
     # own read of calldata.
     negative_word = self.env.calldata.read(0, 32)
     self.state.solver.add(claripy.SLT(negative_word, 0))
     self.assertFalse(self.check_state(self.state))