示例#1
0
    def _translate_bsh(self, oprnd1, oprnd2, oprnd3):
        """Return a formula representation of a BSH instruction.

        BSH shifts ``oprnd1`` by ``oprnd2`` bits and writes the result to
        ``oprnd3``: when ``oprnd2 >= 0`` the value is shifted left by
        ``oprnd2``, otherwise it is shifted right by ``-oprnd2``. The result
        is sized to ``oprnd3.size`` bits (operands are widened before
        shifting when the destination is wider, truncated when narrower).

        Args:
            oprnd1: value to shift.
            oprnd2: shift amount (same size as ``oprnd1``).
            oprnd3: destination operand.

        Returns:
            list: the shift constraint, followed by the constraints that
            ``_translate_dst_oprnd`` returned for ``oprnd3`` (if any).
        """
        assert oprnd1.size and oprnd2.size and oprnd3.size
        assert oprnd1.size == oprnd2.size

        op1_var = self._translate_src_oprnd(oprnd1)
        op2_var = self._translate_src_oprnd(oprnd2)
        op3_var, parent_reg_constrs = self._translate_dst_oprnd(oprnd3)

        if oprnd3.size > oprnd1.size:
            # Widen first so bits shifted left past oprnd1.size are kept.
            # The left-shift amount is only selected when op2_var >= 0, so
            # zero-extending it is enough; the right-shift amount (-op2_var)
            # is sign-extended.
            op1_var_zx = smtlibv2.ZEXTEND(op1_var, oprnd3.size)
            op2_var_zx = smtlibv2.ZEXTEND(op2_var, oprnd3.size)
            op2_var_neg_sx = smtlibv2.SEXTEND(-op2_var, oprnd2.size,
                                              oprnd3.size)

            shl = smtlibv2.EXTRACT(op1_var_zx << op2_var_zx, 0, op3_var.size)
            shr = smtlibv2.EXTRACT(op1_var_zx >> op2_var_neg_sx, 0,
                                   op3_var.size)
        elif oprnd3.size < oprnd1.size:
            shl = smtlibv2.EXTRACT(op1_var << op2_var, 0, op3_var.size)
            shr = smtlibv2.EXTRACT(op1_var >> (-op2_var), 0, op3_var.size)
        else:
            shl = op1_var << op2_var
            shr = op1_var >> (-op2_var)

        # NOTE(review): assumes BitVec ``>=`` is a signed comparison, so a
        # negative shift amount selects the right shift -- confirm in smtlibv2.
        rv = [op3_var == smtlibv2.ITEBV(oprnd3.size, op2_var >= 0, shl, shr)]

        # Keep the untouched bits of a parent register, as the other
        # translators (xor, ldm, str) do.
        if parent_reg_constrs:
            rv += parent_reg_constrs

        return rv
示例#2
0
    def _translate_dst_register_oprnd(self, operand):
        """Translate destination register operand to SMT expr.

        A new SMT variable is created for the register being written. When
        ``operand`` is an alias of a larger base register (it appears in
        ``self._arch_alias_mapper``), the new variable covers the whole base
        register and the returned expression is its slice
        ``[offset, offset + operand.size)``. The bits outside that slice are
        constrained to equal the previous version of the base register.

        Args:
            operand: destination register operand (``name`` and ``size`` are
                used).

        Returns:
            tuple: ``(expr, parent_reg_constrs)``. ``parent_reg_constrs`` is a
            list of constraints preserving the untouched bits of the base
            register (empty when the alias spans the whole base register), or
            ``None`` when ``operand`` is not an alias.
        """
        reg_info = self._arch_alias_mapper.get(operand.name, None)

        if reg_info:
            var_base_name, offset = reg_info

            # The previous name must be read before requesting a fresh one
            # (presumably fresh=True advances the register's version).
            old_var_name = self._get_var_name(var_base_name, fresh=False)

            var_name = self._get_var_name(var_base_name, fresh=True)
            var_size = self._arch_regs_size[var_base_name]

            new_parent = self._solver.mkBitVec(var_size, var_name)

            ret_val = smtlibv2.EXTRACT(new_parent, offset, operand.size)

            old_parent = self._solver.mkBitVec(var_size, old_var_name)

            constrs = []

            # Bits below the written slice: [0, offset).
            if offset > 0:
                constrs.append(smtlibv2.EXTRACT(new_parent, 0, offset) ==
                               smtlibv2.EXTRACT(old_parent, 0, offset))

            # Bits above the written slice: [offset + operand.size, var_size).
            # Skipped when the slice reaches the top, which avoids emitting a
            # zero-width EXTRACT.
            upper_offset = offset + operand.size
            if upper_offset < var_size:
                upper_size = var_size - upper_offset
                constrs.append(
                    smtlibv2.EXTRACT(new_parent, upper_offset, upper_size) ==
                    smtlibv2.EXTRACT(old_parent, upper_offset, upper_size))

            parent_reg_constrs = constrs
        else:
            var_name = self._get_var_name(operand.name, fresh=True)
            ret_val = self._solver.mkBitVec(operand.size, var_name)

            parent_reg_constrs = None

        return ret_val, parent_reg_constrs
示例#3
0
    def _get_constrs_arithmetic_load(self, gadget):
        """Build verification constraints for an ArithmeticLoad gadget.

        The gadget is expected to compute
        ``dst_reg <- dst_reg OP mem[src_reg + offset]``, where ``OP`` is looked
        up in ``self._arithmetic_ops`` by ``gadget.operation``. The first
        operand of ``OP`` is the pre-state value of ``sources[0]``. The memory
        address is ``sources[1] + sources[2]`` when ``sources[1]`` is a
        non-empty register operand, otherwise just the immediate
        ``sources[2]``.

        Returns:
            list: one ``!=`` constraint per byte of the destination, most
            significant byte first, comparing ``OP(src, mem)`` with the
            post-state destination. When at least one general-purpose register
            is absent from ``gadget.modified_registers``, a single OR of
            ``pre != post`` constraints over those registers follows.
        """
        apply_op = self._arithmetic_ops[gadget.operation]

        dst_operand = gadget.destination[0]
        dst_final = self.analyzer.get_register_expr(dst_operand.name,
                                                    mode="post")
        width = dst_operand.size

        base_operand = gadget.sources[1]
        offset_operand = gadget.sources[2]
        has_base_register = (
            isinstance(base_operand, ReilRegisterOperand) and
            not isinstance(base_operand, ReilEmptyOperand))

        if has_base_register:
            base_expr = self.analyzer.get_register_expr(base_operand.name,
                                                        mode="pre")
            offset_expr = self.analyzer.get_immediate_expr(
                offset_operand.immediate, offset_operand.size)
            address = base_expr + offset_expr
        else:
            address = self.analyzer.get_immediate_expr(
                offset_operand.immediate, offset_operand.size)

        lhs = self.analyzer.get_register_expr(gadget.sources[0].name,
                                              mode="pre")
        rhs = self.analyzer.get_memory_expr(address, width / 8)
        expected = apply_op(lhs, rhs)

        byte_mismatches = [
            smtlibv2.EXTRACT(expected, bit, 8) !=
            smtlibv2.EXTRACT(dst_final, bit, 8)
            for bit in reversed(xrange(0, width, 8))
        ]

        # Registers the gadget does not declare as modified must keep
        # their value.
        modified_names = set(reg.name for reg in gadget.modified_registers)
        clobbered = []
        for reg_name in self._arch_info.registers_gp_base:
            if reg_name in modified_names:
                continue
            before = self.analyzer.get_register_expr(reg_name, mode="pre")
            after = self.analyzer.get_register_expr(reg_name, mode="post")
            clobbered.append(before != after)

        if clobbered:
            # Fold into one disjunction: newest term on the left.
            any_clobbered = clobbered[0]
            for constr in clobbered[1:]:
                any_clobbered = constr | any_clobbered
            clobbered = [any_clobbered]

        return byte_mismatches + clobbered
示例#4
0
    def _translate_bsh(self, oprnd1, oprnd2, oprnd3):
        """Return a formula representation of a BSH instruction.

        BSH shifts ``oprnd1`` by ``oprnd2`` bits and writes the result to
        ``oprnd3``: when ``oprnd2 >= 0`` the value is shifted left by
        ``oprnd2``, otherwise it is shifted right by ``-oprnd2``. The result
        is sized to ``oprnd3.size`` bits (operands are widened before
        shifting when the destination is wider, truncated when narrower).

        Args:
            oprnd1: value to shift.
            oprnd2: shift amount (same size as ``oprnd1``).
            oprnd3: destination operand.

        Returns:
            list: the shift constraint, followed by the constraints that
            ``_translate_dst_oprnd`` returned for ``oprnd3`` (if any).
        """
        assert oprnd1.size and oprnd2.size and oprnd3.size
        assert oprnd1.size == oprnd2.size

        op1_var = self._translate_src_oprnd(oprnd1)
        op2_var = self._translate_src_oprnd(oprnd2)
        op3_var, parent_reg_constrs = self._translate_dst_oprnd(oprnd3)

        if oprnd3.size > oprnd1.size:
            # EXTRACTing oprnd3.size bits straight from an oprnd1.size-bit
            # value is ill-sized, so widen first; this also keeps bits shifted
            # left past oprnd1.size. The left-shift amount is only selected
            # when op2_var >= 0, so zero-extending it is enough; the
            # right-shift amount (-op2_var) is sign-extended.
            op1_var_zx = smtlibv2.ZEXTEND(op1_var, oprnd3.size)
            op2_var_zx = smtlibv2.ZEXTEND(op2_var, oprnd3.size)
            op2_var_neg_sx = smtlibv2.SEXTEND(-op2_var, oprnd2.size,
                                              oprnd3.size)

            shl = smtlibv2.EXTRACT(op1_var_zx << op2_var_zx, 0, op3_var.size)
            shr = smtlibv2.EXTRACT(op1_var_zx >> op2_var_neg_sx, 0,
                                   op3_var.size)
        elif oprnd3.size < oprnd1.size:
            shl = smtlibv2.EXTRACT(op1_var << op2_var, 0, op3_var.size)
            shr = smtlibv2.EXTRACT(op1_var >> (-op2_var), 0, op3_var.size)
        else:
            shl = op1_var << op2_var
            shr = op1_var >> (-op2_var)

        # NOTE(review): assumes BitVec ``>=`` is a signed comparison, so a
        # negative shift amount selects the right shift -- confirm in smtlibv2.
        rv = [op3_var == smtlibv2.ITEBV(oprnd3.size, op2_var >= 0, shl, shr)]

        # Keep the untouched bits of a parent register, as the other
        # translators (xor, ldm, str) do.
        if parent_reg_constrs:
            rv += parent_reg_constrs

        return rv
示例#5
0
    def get_register_expr(self, register_name, mode="post"):
        """Return a smt bit vector that represents a register.

        Args:
            register_name: register to look up. If it is an alias (present in
                ``self._arch_info.alias_mapper``), the result is the slice of
                its base register at the alias offset, with the alias's own
                size from ``self._arch_info.registers_size``.
            mode: ``"pre"`` for the register's initial variable
                (``translator.get_init_name``) or ``"post"`` for its current
                variable (``translator.get_curr_name``).

        Returns:
            The bit-vector expression for the register.

        Raises:
            ValueError: if ``mode`` is neither ``"pre"`` nor ``"post"``
                (a subclass of the plain ``Exception`` raised previously).
        """
        # Validate the mode once instead of in each branch.
        if mode == "pre":
            get_var_name = self._translator.get_init_name
        elif mode == "post":
            get_var_name = self._translator.get_curr_name
        else:
            raise ValueError(
                "Invalid mode: {0!r} (expected 'pre' or 'post')".format(mode))

        reg_info = self._arch_info.alias_mapper.get(register_name, None)

        if reg_info:
            var_base_name, offset = reg_info

            var_name = get_var_name(var_base_name)
            var_size = self._arch_info.registers_size[var_base_name]

            # NOTE(review): this branch uses the translator's solver while the
            # non-alias branch uses self._solver -- confirm they are the same.
            ret_val = self._translator._solver.mkBitVec(var_size, var_name)
            ret_val = smtlibv2.EXTRACT(
                ret_val, offset, self._arch_info.registers_size[register_name])
        else:
            var_name = get_var_name(register_name)
            var_size = self._arch_info.registers_size[register_name]

            ret_val = self._solver.mkBitVec(var_size, var_name)

        return ret_val
示例#6
0
    def _translate_stm(self, oprnd1, oprnd2, oprnd3):
        """Return a formula representation of a STM instruction.

        Stores ``oprnd1`` into memory at the address held in ``oprnd3``, one
        byte at a time: bits ``[i, i + 8)`` of the value go to address
        ``oprnd3 + i / 8`` (lowest byte at the lowest address). ``oprnd2`` is
        not used.

        After the stores, ``self._mem_instance`` is incremented and a new
        memory array named ``"MEM_<n>"`` replaces ``self._mem``.

        Returns:
            list: a single constraint ``mem_new == mem_old`` linking the new
            memory version to the array that was written to.
        """
        assert oprnd1.size and oprnd3.size
        assert oprnd3.size == self._address_size

        op1_var = self._translate_src_oprnd(oprnd1)
        op3_var = self._translate_src_oprnd(oprnd3)

        where = op3_var
        size = oprnd1.size

        # NOTE(review): assumes item assignment on the smtlibv2 array records
        # the store in place, so that mem_old below reflects these writes --
        # confirm against smtlibv2's Array implementation.
        for i in xrange(0, size, 8):
            self._mem[where + i / 8] = smtlibv2.EXTRACT(op1_var, i, 8)

        # Memory versioning.
        self._mem_instance += 1

        # Snapshot the written array *after* the stores above, then install a
        # new array variable as the current memory.
        mem_old = self._mem
        mem_new = self._solver.mkArray(self._address_size,
                                       "MEM_" + str(self._mem_instance))

        self._mem = mem_new

        return [mem_new == mem_old]
示例#7
0
    def _translate_ldm(self, oprnd1, oprnd2, oprnd3):
        """Return a formula representation of a LDM instruction.

        Loads ``oprnd3.size`` bits from memory at the address held in
        ``oprnd1`` into ``oprnd3``, one byte at a time: the memory byte at
        ``oprnd1 + i // 8`` is tied to bits ``[i, i + 8)`` of the destination
        (lowest byte at the lowest address). ``oprnd2`` is not used.

        Returns:
            list: one equality per byte (most significant byte first),
            followed by the constraints that ``_translate_dst_oprnd`` returned
            for ``oprnd3`` (if any).
        """
        assert oprnd1.size == self._address_size
        assert oprnd3.size

        op1_var = self._translate_src_oprnd(oprnd1)
        op3_var, parent_reg_constrs = self._translate_dst_oprnd(oprnd3)

        where = op1_var

        rv = []

        for i in reversed(xrange(0, oprnd3.size, 8)):
            mem_byte = smtlibv2.ord(self._mem[where + i // 8])
            dst_byte = smtlibv2.EXTRACT(op3_var, i, 8)

            rv.append(mem_byte == dst_byte)

        if parent_reg_constrs:
            rv += parent_reg_constrs

        # Return the list that actually received the parent constraints
        # (previously this only worked because rv aliased exprs).
        return rv
示例#8
0
    def _translate_xor(self, oprnd1, oprnd2, oprnd3):
        """Return a formula representation of a XOR instruction.

        ``oprnd3`` receives ``oprnd1 ^ oprnd2``. The result is zero-extended
        when the destination is wider than the sources, and truncated to its
        low ``oprnd3.size`` bits when it is narrower.

        Returns:
            list: the XOR constraint, followed by the constraints that
            ``_translate_dst_oprnd`` returned for ``oprnd3`` (if any).
        """
        assert oprnd1.size and oprnd2.size and oprnd3.size
        assert oprnd1.size == oprnd2.size

        op1_var = self._translate_src_oprnd(oprnd1)
        op2_var = self._translate_src_oprnd(oprnd2)
        op3_var, parent_reg_constrs = self._translate_dst_oprnd(oprnd3)

        result = op1_var ^ op2_var

        # Resize the result to the destination width.
        if oprnd1.size < oprnd3.size:
            result = smtlibv2.ZEXTEND(result, oprnd3.size)
        elif oprnd1.size > oprnd3.size:
            result = smtlibv2.EXTRACT(result, 0, oprnd3.size)

        rv = [op3_var == result]

        if parent_reg_constrs:
            rv += parent_reg_constrs

        return rv
示例#9
0
    def _translate_dst_register_oprnd(self, operand):
        """Translate destination register operand to SMT expr.

        A new SMT variable is created for the register being written. When
        ``operand`` is an alias of a larger base register (it appears in
        ``self._arch_alias_mapper``), the new variable covers the whole base
        register and the returned expression is its slice
        ``[offset, offset + operand.size)``. The bits outside that slice are
        constrained to equal the previous version of the base register.

        The constraints are built on exact bit ranges rather than whole bytes.
        This keeps aliases that are not byte-aligned (e.g. a 1-bit slice at
        offset 3) correct: the written bits are never forced to their old
        value. It also avoids out-of-range extracts when the base size is not
        a multiple of 8.

        Args:
            operand: destination register operand (``name`` and ``size`` are
                used).

        Returns:
            tuple: ``(expr, parent_reg_constrs)``. ``parent_reg_constrs`` is a
            list of constraints preserving the untouched bits of the base
            register (empty when the alias spans the whole base register), or
            ``None`` when ``operand`` is not an alias.
        """
        reg_info = self._arch_alias_mapper.get(operand.name, None)

        if reg_info:
            var_base_name, offset = reg_info

            # The previous name must be read before requesting a fresh one
            # (presumably fresh=True advances the register's version).
            old_var_name = self._get_var_name(var_base_name, fresh=False)

            var_name = self._get_var_name(var_base_name, fresh=True)
            var_size = self._arch_regs_size[var_base_name]

            new_parent = self._solver.mkBitVec(var_size, var_name)

            ret_val = smtlibv2.EXTRACT(new_parent, offset, operand.size)

            old_parent = self._solver.mkBitVec(var_size, old_var_name)

            constrs = []

            # Bits below the written slice: [0, offset).
            if offset > 0:
                constrs.append(smtlibv2.EXTRACT(new_parent, 0, offset) ==
                               smtlibv2.EXTRACT(old_parent, 0, offset))

            # Bits above the written slice: [offset + operand.size, var_size).
            upper_offset = offset + operand.size
            if upper_offset < var_size:
                upper_size = var_size - upper_offset
                constrs.append(
                    smtlibv2.EXTRACT(new_parent, upper_offset, upper_size) ==
                    smtlibv2.EXTRACT(old_parent, upper_offset, upper_size))

            parent_reg_constrs = constrs
        else:
            var_name = self._get_var_name(operand.name, fresh=True)
            ret_val = self._solver.mkBitVec(operand.size, var_name)

            parent_reg_constrs = None

        return ret_val, parent_reg_constrs
示例#10
0
    def _translate_str(self, oprnd1, oprnd2, oprnd3):
        """Return a formula representation of a STR instruction.

        Copies ``oprnd1`` into ``oprnd3``. The source is zero-extended when
        the destination is wider, and only its low ``oprnd3.size`` bits are
        kept when the destination is narrower. ``oprnd2`` is not used.

        Returns:
            list: the copy constraint, followed by the constraints that
            ``_translate_dst_oprnd`` returned for ``oprnd3`` (if any).
        """
        assert oprnd1.size and oprnd3.size

        op1_var = self._translate_src_oprnd(oprnd1)
        op3_var, parent_reg_constrs = self._translate_dst_oprnd(oprnd3)

        if oprnd1.size == oprnd3.size:
            expr = (op1_var == op3_var)
        elif oprnd1.size < oprnd3.size:
            # Zero-extension pins both the low bits (equal to the source) and
            # the high bits (all zero) of the destination in one constraint,
            # replacing the former hand-built zero literal.
            expr = (op3_var == smtlibv2.ZEXTEND(op1_var, oprnd3.size))
        else:
            expr = (smtlibv2.EXTRACT(op1_var, 0, op3_var.size) == op3_var)

        rv = [expr]

        if parent_reg_constrs:
            rv += parent_reg_constrs

        return rv
示例#11
0
    def _translate_src_register_oprnd(self, operand):
        """Translate a source register operand into an SMT expression.

        Registers found in ``self._arch_alias_mapper`` are read as the slice
        ``[offset, offset + operand.size)`` of their base register's current
        variable. Any other register is read as a bit vector of
        ``operand.size`` bits under its own current variable name.
        """
        alias = self._arch_alias_mapper.get(operand.name, None)

        if not alias:
            return self._solver.mkBitVec(operand.size,
                                         self._get_var_name(operand.name))

        base_name, bit_offset = alias
        base_var_name = self._get_var_name(base_name)
        base_var = self._solver.mkBitVec(self._arch_regs_size[base_name],
                                         base_var_name)

        return smtlibv2.EXTRACT(base_var, bit_offset, operand.size)
示例#12
0
    def _get_constrs_store_memory(self, gadget):
        """Build verification constraints for a StoreMemory gadget.

        The gadget is expected to perform ``mem[dst_reg + offset] <- src_reg``.
        The address is ``destination[0] + destination[1]`` when
        ``destination[0]`` is a non-empty register operand, otherwise just the
        immediate ``destination[1]``.

        Returns:
            list: one ``!=`` constraint per byte of ``sources[0]``, most
            significant byte first, comparing the memory byte at
            ``address + i / 8`` with bits ``[i, i + 8)`` of the source's
            pre-state value. When at least one general-purpose register is
            absent from ``gadget.modified_registers``, a single OR of
            ``pre != post`` constraints over those registers follows.
        """
        base_operand = gadget.destination[0]
        offset_operand = gadget.destination[1]
        has_base_register = (
            isinstance(base_operand, ReilRegisterOperand) and
            not isinstance(base_operand, ReilEmptyOperand))

        if has_base_register:
            base_expr = self.analyzer.get_register_expr(base_operand.name,
                                                        mode="pre")
            offset_expr = self.analyzer.get_immediate_expr(
                offset_operand.immediate, offset_operand.size)
            address = base_expr + offset_expr
        else:
            address = self.analyzer.get_immediate_expr(
                offset_operand.immediate, offset_operand.size)

        src_operand = gadget.sources[0]
        src_expr = self.analyzer.get_register_expr(src_operand.name,
                                                   mode="pre")
        width = src_operand.size

        byte_mismatches = [
            self.analyzer.get_memory_expr(address + bit / 8, 1) !=
            smtlibv2.EXTRACT(src_expr, bit, 8)
            for bit in reversed(xrange(0, width, 8))
        ]

        # Registers the gadget does not declare as modified must keep
        # their value.
        modified_names = set(reg.name for reg in gadget.modified_registers)
        clobbered = []
        for reg_name in self._arch_info.registers_gp_base:
            if reg_name in modified_names:
                continue
            before = self.analyzer.get_register_expr(reg_name, mode="pre")
            after = self.analyzer.get_register_expr(reg_name, mode="post")
            clobbered.append(before != after)

        if clobbered:
            # Fold into one disjunction: newest term on the left.
            any_clobbered = clobbered[0]
            for constr in clobbered[1:]:
                any_clobbered = constr | any_clobbered
            clobbered = [any_clobbered]

        return byte_mismatches + clobbered
示例#13
0
    def _translate_and(self, oprnd1, oprnd2, oprnd3):
        """Return a formula representation of an AND instruction.

        ``oprnd3`` receives ``oprnd1 & oprnd2``. The result is zero-extended
        when the destination is wider than the sources, and truncated to its
        low ``oprnd3.size`` bits when it is narrower.

        Returns:
            list: the AND constraint, followed by the constraints that
            ``_translate_dst_oprnd`` returned for ``oprnd3`` (if any).
        """
        assert oprnd1.size and oprnd2.size and oprnd3.size
        assert oprnd1.size == oprnd2.size

        op1_var = self._translate_src_oprnd(oprnd1)
        op2_var = self._translate_src_oprnd(oprnd2)
        op3_var, parent_reg_constrs = self._translate_dst_oprnd(oprnd3)

        result = op1_var & op2_var

        # Resize the result to the destination width.
        if oprnd1.size < oprnd3.size:
            result = smtlibv2.ZEXTEND(result, oprnd3.size)
        elif oprnd1.size > oprnd3.size:
            result = smtlibv2.EXTRACT(result, 0, oprnd3.size)

        rv = [op3_var == result]

        # Previously discarded; keep the untouched bits of a parent register
        # consistent with _translate_xor.
        if parent_reg_constrs:
            rv += parent_reg_constrs

        return rv