def visit_MLIL_LOAD_SSA(self, expr):
        """Build the value read by a memory load expression.

        The source operand is visited first; when that yields ``None`` the
        load cannot be modelled and ``None`` is returned.  Otherwise the
        bytes starting at the visited address are read from ``self._memory``
        and combined.

        Only loads of 1, 2, 4 or 8 bytes are handled; any other size falls
        through and returns ``None``.
        """
        address = self.visit(expr.src)
        if address is None:
            return None

        mem = self._memory
        width = expr.size

        # A single byte needs no concatenation.
        if width == 1:
            return mem[address]

        # we're assuming Little Endian for now: the byte at the highest
        # address is passed to Concat first, the byte at `address` last.
        for supported in (2, 4, 8):
            if width == supported:
                return Concat(
                    *[mem[address + i] for i in range(supported - 1, -1, -1)]
                )

        return None
Exemple #2
0
    def mload(self, offset: BitVecNumRef):
        """Read one word (``WORDBYTESIZE`` entries) starting at ``offset``.

        ``offset`` may be a concrete ``BitVecNumRef`` or a plain ``int``;
        anything else raises ``DevelopmentErorr`` because symbolic indices
        are not supported.

        When the requested range lies inside the stored data, the entries
        are concatenated and simplified.  When it runs past the end, a fresh
        memory variable is generated: if only part of the word is missing,
        the missing entries are appended as byte extracts of that variable
        and the word is read back; if the whole word is missing, the
        variable is stored at ``offset`` via ``mstore`` and returned.
        """
        if isinstance(offset, BitVecNumRef):
            offset = offset.as_long()
        elif not isinstance(offset, int):
            raise DevelopmentErorr(
                'Does not support memory operations indexed by symbol variables.'
            )

        end = offset + WORDBYTESIZE

        # Fast path: the whole word is already backed by stored data.
        if end <= len(self.__immediate_data):
            return simplify(Concat(self.__immediate_data[offset:end]))

        # ~ index out of bounds ~
        # generate a symbolic variable to stand in for the unknown bytes
        fresh_var = self.__generateMemoryVar()
        missing = end - len(self.__immediate_data)

        if missing >= WORDBYTESIZE:
            # Nothing of the word exists yet: store the fresh word as a whole.
            self.mstore(BitVecVal256(offset), fresh_var)
            return fresh_var

        # Append the missing entries, highest extracted byte first.
        for byte_index in range(missing - 1, -1, -1):
            low_bit = byte_index * 8
            self.__immediate_data.append(
                Extract(low_bit + 7, low_bit, fresh_var))
        return simplify(Concat(self.__immediate_data[offset:end]))
def right_one_extension(formula, bit_places):
    """Keep the top ``bit_places`` bits and set every bit to their right to 1.

    The result is the concatenation of the kept high bits of ``formula``
    with an all-ones bit-vector covering the remaining low positions.
    """
    width = formula.size()
    low_width = width - bit_places

    # Zero minus one wraps around to an all-ones vector of `low_width` bits.
    all_ones = BitVecVal(0, low_width) - 1
    kept_high_bits = Extract(width - 1, low_width, formula)

    return Concat(kept_high_bits, all_ones)
Exemple #4
0
    def is_byte_swap(self):
        """Decide whether the modelled variable is a byte-swapped copy of its source.

        Returns ``False`` when ``model_variable()`` raises
        ``ModelIsConstrained`` or when the number of recorded byte values is
        not in the range ``(1, self.var.src.var.type.width]``.  Otherwise a
        constraint is added to ``self.solver`` and ``True`` is returned only
        if the solver reports ``unsat`` for it.

        NOTE(review): the constraint is added with no push/pop in this
        method, so it appears to remain in ``self.solver`` afterwards —
        confirm callers expect that.
        """
        try:
            self.model_variable()
        except ModelIsConstrained:
            return False

        # Figure out if this might be a byte swap
        byte_values_len = len(self.byte_values)
        #print self.byte_values
        if 1 < byte_values_len <= self.var.src.var.type.width:
            var = create_BitVec(self.var.src, self.var.src.var.type.width)

            # Byte values taken in ascending key order, then reversed.
            ordering = list(reversed([
                self.byte_values[x]
                for x in sorted(self.byte_values.keys())
            ]))

            # The source variable with its low len(ordering) bytes in
            # reversed order: the lowest byte becomes the first (highest)
            # Concat argument.
            reverse_var = Concat(
                *reversed([
                    Extract(i-1, i-8, var)
                    for i in range(len(ordering) * 8, 0, -8)
                ])
            )

            # With fewer than 4 bytes, the untouched upper bits (up to bit
            # 31) of the source are put back on top.
            # NOTE(review): the hard-coded 31 presumes a 32-bit variable —
            # confirm.
            if len(ordering) < 4:
                reverse_var = Concat(
                    Extract(
                        31,
                        len(ordering)*8, var
                    ),
                    reverse_var
                )

            reversed_ordering = reversed(ordering)
            reversed_ordering = Concat(*reversed_ordering)

            # The idea here is that if we add the negation of this, if it's
            # not satisfiable, then that means there is no value such that
            # the equivalence does not hold. If that is the case, then this
            # should be a byte-swapped value.

            self.solver.add(
                Not(
                    And(
                        var == ZeroExt(
                            var.size() - len(ordering)*8,
                            Concat(*ordering)
                        ),
                        reverse_var == ZeroExt(
                            reverse_var.size() - reversed_ordering.size(),
                            reversed_ordering
                        )
                    )
                )
            )

            if self.solver.check() == unsat:
                return True

        return False
def sign_extension(formula, bit_places):
    """Set the rest of bits on the left to the value of the sign bit.

    The low ``bit_places`` bits of ``formula`` are kept; every bit above
    them is replaced by a copy of bit ``bit_places - 1`` (the sign bit of
    the kept part).  The result has the same width as ``formula``.

    When ``bit_places`` equals the width of ``formula`` there is nothing to
    extend and ``formula`` is returned unchanged.  (Previously this case
    still prepended one sign bit, producing a result one bit too wide.)
    """
    width = formula.size()
    if bit_places == width:
        return formula

    sign_bit = Extract(bit_places - 1, bit_places - 1, formula)

    # Grow the padding one sign bit at a time until it covers all
    # `width - bit_places` upper positions.
    complement = sign_bit
    for _ in range(width - bit_places - 1):
        complement = Concat(sign_bit, complement)

    return Concat(complement, Extract(bit_places - 1, 0, formula))
def right_sign_extension(formula, bit_places):
    """Set the rest of bits on the right to the value of the sign bit.

    The top ``bit_places`` bits of ``formula`` are kept; every bit below
    position ``formula.size() - bit_places`` is replaced by a copy of the
    bit at that position.  The result has the same width as ``formula``.

    When ``bit_places`` equals the width of ``formula`` there are no lower
    bits to fill and ``formula`` is returned unchanged.  (Previously this
    case still appended one bit, producing a result one bit too wide.)
    """
    width = formula.size()
    sign_bit_position = width - bit_places
    if sign_bit_position == 0:
        return formula

    sign_bit = Extract(sign_bit_position, sign_bit_position, formula)

    # Grow the filler one bit at a time until it covers all
    # `sign_bit_position` lower positions.
    complement = sign_bit
    for _ in range(sign_bit_position - 1):
        complement = Concat(sign_bit, complement)

    return Concat(Extract(width - 1, sign_bit_position, formula), complement)
def zero_extension(formula, bit_places):
    """Keep the low ``bit_places`` bits and set every bit to their left to 0.

    The result is a zero bit-vector covering the upper positions
    concatenated with the kept low bits of ``formula``.
    """
    upper_width = formula.size() - bit_places
    zeros = BitVecVal(0, upper_width)
    kept_low_bits = Extract(bit_places - 1, 0, formula)

    return Concat(zeros, kept_low_bits)
Exemple #8
0
    def test_instruction_semantics_call(self):
        """A call pushes the return address and jumps to the target.

        Checks that stepping ``call #0xdead`` yields exactly one successor
        state in which the two bytes below the old R1 value hold the
        pre-incremented instruction pointer, R1 has dropped by two, and R0
        holds the call target.
        """
        # call #0xdead
        encoded = b'\xb0\x12\xad\xde'
        call_site = 0xc0de
        return_address = call_site + len(encoded)

        instruction, _ = decode_instruction(call_site, encoded)

        initial = blank_state()
        # ip preincrement
        initial.cpu.registers['R0'] = BitVecVal(return_address, 16)
        initial.cpu.registers['R1'] = BitVecVal(0x1234, 16)

        successors = initial.cpu.step_call(initial, instruction)

        self.assertEqual(len(successors), 1)

        after = successors[0]

        # The pushed word is read back low byte first in memory.
        low_byte = after.memory[0x1232]
        high_byte = after.memory[0x1233]
        pushed_word = Concat(high_byte, low_byte)

        self.assertEqual(intval(pushed_word), return_address)
        self.assertEqual(intval(after.cpu.registers['R1']), 0x1232)
        self.assertEqual(intval(after.cpu.registers['R0']), 0xdead)
Exemple #9
0
 def array_to_bv64(array):
     """Concatenate array elements 7 down to 0 into one bit-vector.

     Element 7 is passed to Concat first and element 0 last; indices are
     32-bit bit-vector constants.
     """
     elements = [
         Select(array, BitVecVal(index, 32)) for index in range(7, -1, -1)
     ]
     return Concat(*elements)
Exemple #10
0
 def set_region_bit(bv, p):
   """Return ``bv`` with the bit of the region containing ``p`` forced to 1.

   The bit position is the index, within ``region_names``, of the region
   found at ``REGIONS[p.y][p.x]``.  Bits above and below that position are
   copied from ``bv``.
   """
   position = region_names.index(REGIONS[p.y][p.x])

   # Bits above the target position, if any.
   upper = [Extract(bits - 1, position + 1, bv)] if position < bits - 1 else []
   one_bit = [BitVecVal(1, 1)]
   # Bits below the target position, if any.
   lower = [Extract(position - 1, 0, bv)] if position > 0 else []

   return Concat(*(upper + one_bit + lower))
Exemple #11
0
def identity(data: Union[bytes, str, List[int]]) -> bytes:
    """Return ``data`` exactly as it was received.

    No copy or conversion is made: the same object that was passed in is
    handed back, whatever its type.

    :param data: the input payload.
    :return: ``data`` itself.
    """
    # NOTE: a previous revision carried code after this return that grouped
    # the input into 32-byte words (the form needed when saving to a
    # word-indexed memory).  It could never run and has been removed; the
    # function's behavior is unchanged.
    return data
Exemple #12
0
    def byte_(self, global_state):
        """Execute the BYTE operation on the machine stack.

        Pops the byte index (``op0``) and the word (``op1``) and pushes the
        selected byte, zero-extended to 256 bits.  Index 0 selects the most
        significant byte.

        An index greater than 31 lies outside the word and produces a
        256-bit zero.  (Previously such an index led to an ``Extract`` with
        a negative bit range.)

        When the index is not concrete, a fresh 256-bit symbol named after
        both operands is pushed instead.

        :return: a one-element list containing ``global_state``.
        """
        mstate = global_state.mstate
        op0, op1 = mstate.stack.pop(), mstate.stack.pop()

        try:
            index = util.get_concrete_int(op0)
            # Index counts from the most significant byte; Extract counts
            # bits from the least significant end.
            offset = (31 - index) * 8
            if offset >= 0:
                result = Concat(BitVecVal(0, 248),
                                Extract(offset + 7, offset, op1))
            else:
                # Index past the last byte of the word.
                result = BitVecVal(0, 256)
        except AttributeError:
            logging.debug("BYTE: Unsupported symbolic byte offset")
            result = BitVec(str(simplify(op1)) + "_" + str(simplify(op0)), 256)

        mstate.stack.append(simplify(result))
        return [global_state]
    def byte_(self, global_state):
        """Execute the BYTE operation on the machine stack.

        Pops the byte index (``op0``) and the word (``op1``); a word that is
        not already an expression is wrapped as a 256-bit constant.  For a
        concrete index inside the word the selected byte, zero-extended to
        256 bits and simplified, is pushed; for an index past the word the
        plain integer ``0`` is pushed.  When the index is not concrete, a
        fresh 256-bit symbol named after both operands is pushed.

        :return: a one-element list containing ``global_state``.
        """
        machine_state = global_state.mstate
        op0, op1 = machine_state.stack.pop(), machine_state.stack.pop()
        if not isinstance(op1, ExprRef):
            op1 = BitVecVal(op1, 256)

        try:
            index = util.get_concrete_int(op0)
            # Index counts from the most significant byte; Extract counts
            # bits from the least significant end.
            offset = (31 - index) * 8
            if offset < 0:
                result = 0
            else:
                padded_byte = Concat(BitVecVal(0, 248),
                                     Extract(offset + 7, offset, op1))
                result = simplify(padded_byte)
        except AttributeError:
            logging.debug("BYTE: Unsupported symbolic byte offset")
            result = BitVec(str(simplify(op1)) + "[" + str(simplify(op0)) + "]", 256)

        machine_state.stack.append(result)
        return [global_state]
Exemple #14
0
    def test_instruction_semantics_push(self):
        """A push stores its operand below R1 and lowers R1 by two.

        Checks that stepping ``push #0xdead`` yields exactly one successor
        state in which the two bytes below the old R1 value hold 0xdead and
        R1 has dropped by two.
        """
        # push #0xdead
        encoded = b'\x30\x12\xad\xde'
        address = 0x1234

        instruction, _ = decode_instruction(address, encoded)

        initial = blank_state()
        initial.cpu.registers['R1'] = BitVecVal(0x1234, 16)

        successors = initial.cpu.step_push(initial, instruction)

        self.assertEqual(len(successors), 1)

        after = successors[0]
        # The pushed word is read back low byte first in memory.
        low_byte = after.memory[0x1232]
        high_byte = after.memory[0x1233]
        pushed_word = Concat(high_byte, low_byte)

        self.assertEqual(intval(pushed_word), 0xdead)
        self.assertEqual(intval(after.cpu.registers['R1']), 0x1232)
Exemple #15
0
    def __getitem__(self, item):
        """Read one calldata entry, or a concatenated slice of entries.

        For a non-slice ``item``: concrete calldata is indexed by the
        concrete integer value of ``item`` and yields an 8-bit zero when the
        index is out of range; otherwise the lookup is delegated to
        ``self._calldata`` directly.

        For a slice, entries from ``item.start`` up to (not including)
        ``item.stop`` are read one by one and concatenated.  A
        ``Z3Exception`` during that walk is reported as ``IndexError``.
        """
        if not isinstance(item, slice):
            if not self.concrete:
                return self._calldata[item]
            try:
                return self._calldata[get_concrete_int(item)]
            except IndexError:
                return BitVecVal(0, 8)

        try:
            if isinstance(item.start, BitVecRef):
                cursor = item.start
            else:
                cursor = BitVecVal(item.start, 256)

            collected = []
            # Walk index by index until the stop index is reached.
            while simplify(cursor != item.stop):
                collected.append(self[cursor])
                cursor = simplify(cursor + 1)
        except Z3Exception:
            raise IndexError("Invalid Calldata Slice")

        return simplify(Concat(collected))
Exemple #16
0
    def __getitem__(self, item: Union[int, slice]) -> Any:
        """Load a single entry, or a concatenated slice of entries.

        An ``int`` or ``ExprRef`` index is passed straight to ``_load``.
        For a slice, missing bounds default to start 0, step 1 and stop
        ``self.size``; entries are loaded one index at a time and
        concatenated.  A ``Z3Exception`` during that walk is reported as
        ``IndexError``.  Any other kind of ``item`` raises ``ValueError``.
        """
        if isinstance(item, (int, ExprRef)):
            return self._load(item)

        if not isinstance(item, slice):
            raise ValueError

        first = item.start if item.start is not None else 0
        stride = item.step if item.step is not None else 1
        last = item.stop if item.stop is not None else self.size

        try:
            if isinstance(first, BitVecRef):
                cursor = first
            else:
                cursor = BitVecVal(first, 256)

            loaded = []
            # Walk index by index until the stop index is reached.
            while simplify(cursor != last):
                loaded.append(self._load(cursor))
                cursor = simplify(cursor + stride)
        except Z3Exception:
            raise IndexError("Invalid Calldata Slice")

        return simplify(Concat(loaded))
Exemple #17
0
    def __getitem__(self, item: Union[int, slice]) -> Any:
        """Read calldata and return ``(value, constraints)``.

        For a non-slice ``item``: concrete calldata is indexed by the
        concrete integer value of ``item`` (an 8-bit zero when out of
        range) with an empty constraint tuple; otherwise the entry is
        returned together with one constraint stating that a non-zero
        entry implies ``calldatasize`` is greater than the index.

        For a slice, missing bounds default to start 0, step 1 and stop
        ``self.calldatasize``; each entry is read through this method, the
        values are concatenated and the per-entry constraints are merged
        into one list.  A ``Z3Exception`` during the walk is reported as
        ``IndexError``.
        """
        if not isinstance(item, slice):
            if not self.concrete:
                in_bounds = [
                    Implies(self._calldata[item] != 0,
                            UGT(self.calldatasize, item))
                ]
                return self._calldata[item], in_bounds
            try:
                return self._calldata[get_concrete_int(item)], ()
            except IndexError:
                return BitVecVal(0, 8), ()

        first, stride, last = item.start, item.step, item.stop
        try:
            first = 0 if first is None else first
            stride = 1 if stride is None else stride
            last = self.calldatasize if last is None else last

            if isinstance(first, BitVecRef):
                cursor = first
            else:
                cursor = BitVecVal(first, 256)

            pairs = []
            # Walk index by index until the stop index is reached.
            while simplify(cursor != last):
                pairs.append(self[cursor])
                cursor = simplify(cursor + stride)
        except Z3Exception:
            raise IndexError("Invalid Calldata Slice")

        # Split the (value, constraints) pairs and flatten the constraints.
        values, per_entry_constraints = zip(*pairs)
        merged_constraints = []
        for entry_constraints in per_entry_constraints:
            merged_constraints.extend(entry_constraints)
        return simplify(Concat(values)), merged_constraints
Exemple #18
0
def BVSignedUpCast(x, n_bits):
	"""Widen ``x`` to ``n_bits`` bits, filling the new upper bits by sign.

	The padding is all ones when ``x < 0`` and all zeros otherwise.  A
	value that already has ``n_bits`` bits is returned unchanged.
	"""
	assert x.size() <= n_bits
	if not x.size() < n_bits:
		return x
	negative_pad = BitVecVal(-1, n_bits - x.size())
	non_negative_pad = BitVecVal(0, n_bits - x.size())
	return Concat(If(x < 0, negative_pad, non_negative_pad), x)
Exemple #19
0
from pwn import *
import itertools
from z3 import BitVec, BitVecVal, Solver, If, simplify, Concat, Extract, And, Or, LShR, UDiv, sat, Int
from ictf import iCTF
from multiprocessing import Pool
# NOTE(review): the login credentials are hard-coded in the source; consider
# moving them out of the script.
i = iCTF()
team = i.login('*****@*****.**', 'HZBrKynG4XAMvWPa')

# Short aliases for the z3 constructors.
BVV, BV = BitVecVal, BitVec


def int128(x):
    """Prefix ``x`` with 64 zero bits."""
    return Concat(BVV(0, 64), x)


def int64(x):
    """Keep the low 64 bits of ``x``."""
    return Extract(63, 0, x)

def weird_op(x):
    """Return ``x`` minus 1337331 times a scaled high-multiply of ``x``.

    ``x`` is widened with ``int128``, multiplied by a fixed constant, and
    the product is shifted right by 64 and then by a further 20 bits; the
    result, times 1337331, is subtracted from ``x``.

    NOTE(review): this looks like a multiply-and-shift reduction modulo
    1337331 — confirm against the original binary.
    """
    wide_product = 0x0C8B98A756AA1D561 * int128(x)
    quotient = LShR(int64(LShR(wide_product, 64)), 20)
    return int64(x - (1337331 * quotient))

def choose_op(a1, a2, op):
    r = If(op == 0,
            a2 + a1,
            If(op == 1,
                a1 - a2,
                If(op == 2,
                    a2 * a1,
                    If(op == 3 and a2 != 0,
                        UDiv(a1, a2),
                        False
                    )
                )
            )
Exemple #20
0
def make_load(src, size):
    """Build the expression for loading ``size`` bytes starting at ``src``.

    Memory is modelled as a z3 array named ``'mem'`` from 32-bit addresses
    to 8-bit values.  The bytes at ``src``, ``src + 1``, ... are passed to
    ``Concat`` in that order, so the byte at ``src`` ends up in the most
    significant position of the result.

    A one-byte load returns the selected byte directly, because ``Concat``
    needs at least two arguments.

    :param src: address expression (or integer) of the first byte.
    :param size: number of bytes to load.
    :return: the loaded value as a bit-vector expression.
    """
    mem = Array('mem', BitVecSort(32), BitVecSort(8))

    load_bytes = [mem[src + i] for i in range(size)]

    # Concat rejects a single argument; one byte is already the answer.
    if len(load_bytes) == 1:
        return load_bytes[0]

    return Concat(*load_bytes)
Exemple #21
0
# Model a 16-byte flag and replay the shuffle / add / xor transformation
# symbolically.
s = Solver()
flag = [BitVec(f'f{i}', 8) for i in range(16)]

shuffle = bytes.fromhex('02060701050B090E030F04080A0C0D00')
add32 = bytes.fromhex('EFBEADDEADDEE1FE3713371366746367')
xor = bytes.fromhex('7658B4498D1A5F38D423F834EB86F9AA')
expected_prefix = b'CTF{}'[:-1]

# Pin the leading flag bytes to the known prefix.
for f, v in zip(flag, expected_prefix):
    s.add(f == v)

# Step 1: permute the flag bytes.
dest = [flag[b] for b in shuffle]

# Step 2: pack each group of four bytes into a little-endian 32-bit word.
words = [
    Concat(dest[i + 3], dest[i + 2], dest[i + 1], dest[i])
    for i in range(0, len(dest), 4)
]

# Step 3: add the matching little-endian 32-bit constant to every word.
words = [
    word + struct.unpack_from('<I', add32, 4 * index)[0]
    for index, word in enumerate(words)
]

# Step 4: split the words back into bytes, lowest byte first.
dest = [
    Extract(8 * byte_index + 7, 8 * byte_index, word)
    for word in words
    for byte_index in range(4)
]

# Step 5: xor every byte with its key byte.
dest = [byte ^ key for byte, key in zip(dest, xor)]

for i in range(len(flag)):
from opcodes import BYTE
from rule import Rule
from z3 import BitVec, BitVecVal, Concat, Extract
"""
Checks that the byte opcode (implemented using shift) is equivalent to a
canonical definition of byte using extract.
"""

rule = Rule()

n_bits = 256
x = BitVec('X', n_bits)

for i in range(32):
    # For BYTE, i = 0 selects the most significant byte,
    # whereas Extract numbers bits from the least significant end.
    lsb = 31 - i
    low_bit = 8 * lsb

    actual = BYTE(BitVecVal(i, n_bits), x)
    # Canonical definition: the selected byte, zero-extended to n_bits.
    expected = Concat(BitVecVal(0, n_bits - 8), Extract(low_bit + 7, low_bit, x))

    rule.check(actual, expected)
def find_n_root_domains_ignoring_wildcards(solver: z3.Solver, symbols: Dict[int, List[str]],
                                           require_literal_dot_in_domain: bool,
                                           levels: int, max_finds: int):
    """If enough root domains are found, report vulnerability.

    Repeatedly asks the solver for a model, records the root domain found in
    it, and excludes that root domain before asking again.  The search stops
    once ``max_finds`` distinct root domains were collected or the solver no
    longer answers ``sat``.

    The names ``root_domain``, ``fqdn``, ``proto``, ``proto_delimiter``,
    ``unknown_string`` and ``cc_tlds`` are not defined in this function and
    must be provided by the enclosing module.

    :param solver: solver to query; constraints are added inside a
        ``solver_frame`` context.
    :param symbols: passed through to ``concretizations`` when turning model
        values into concrete strings.
    :param require_literal_dot_in_domain: when true, the root domain is
        required to contain a ``'.'``.
    :param levels: number of levels of domains to consider as part of the root domain. Either 2 or 3.
        If 3, the TLD needs to be a Country Code.
    :param max_finds: threshold number of root domains to find before considering as vulnerability.
    :return: tuple of sat result, witness or debug info
    """

    from z3 import Not, And, Or, Xor, Contains, String, StringVal, Length, Implies, SuffixOf, Concat

    # String symbols for the individual labels of the root domain.
    tld = String('tld')
    sld = String('sld')
    third_ld = String('3ld')
    DNS_root = String('DNS_root')
    dot_in_domain = Contains(root_domain, '.')

    domain_expr = z3.simplify(And(
        # The optional trailing DNS root is either '.' or empty.
        Or(DNS_root == StringVal('.'), DNS_root == StringVal('')),

        # Exactly one of: the FQDN is the root domain itself, or the FQDN
        # ends with '.' + root domain (and the root domain contains a dot).
        Xor(Concat(root_domain, DNS_root) == fqdn,
            And(SuffixOf(Concat(StringVal('.'), root_domain, DNS_root), fqdn),
                dot_in_domain)),

        # A dotted root domain decomposes into 2 or 3 dot-free labels.
        Implies(dot_in_domain,
                And(Not(Contains(tld, '.')),
                    Not(Contains(sld, '.')),
                    Length(tld) > 0,
                    root_domain == (Concat(sld, StringVal('.'), tld)
                                    if levels < 3
                                    else Concat(third_ld, StringVal('.'), sld, StringVal('.'), tld)))),

        # A dot-free root domain is the whole FQDN and has empty tld/sld.
        Implies(Not(dot_in_domain), And(Concat(root_domain, DNS_root) == fqdn,
                                        tld == StringVal(''),
                                        sld == StringVal(''),)),
    ))


    # NOTE(review): solver_frame presumably scopes the added constraints
    # (push/pop) — confirm against its definition.
    with solver_frame(solver):
        solver.add(domain_expr)
        solver.add(RegexStringExpr.ignore_wildcards)
        if require_literal_dot_in_domain or levels >= 3:
            solver.add(dot_in_domain)

        if levels >= 3:
            # Three-level root domains: non-empty labels and a TLD taken
            # from cc_tlds.
            solver.add(z3.simplify(And(
                Not(Contains(third_ld, '.')),
                Length(sld) > 0,
                Length(third_ld) > 0,
                Or(*(tld == StringVal(ss) for ss in cc_tlds)),
            )))

        results = []
        found = set()
        found_root_domains = set()
        result = z3.sat

        logger.info('searching for n root_domains')

        # Turn a z3 string value into its concrete strings.
        _concs = lambda zs: concretizations(z3_str_to_bytes(zs), symbols)
        solution = None

        while len(found_root_domains) < max_finds and result == z3.sat:
            result = solver.check()
            results.append(result)
            if result == z3.sat:
                model = solver.model()
                _root_domain = model[root_domain]
                parts = (model[proto], model[proto_delimiter], model[fqdn])
                assert all(part is not None for part in parts)
                logger.info(public_vars(model))
                found.update(_concs(z3.simplify(z3.Concat(*parts))))
                found_root_domains.update(_concs(_root_domain))
                # Exclude this root domain so the next check finds another.
                solver.add(root_domain != _root_domain)
                solution = tz.first(_concs(model[unknown_string]))

        # Label the collected root domains 'witness' when the last check was
        # sat; otherwise return them as 'root_domains' for debugging.
        if result == z3.sat:
            root_domains_label = 'witness'
        else:
            root_domains_label = 'root_domains'

        return result, {'strategy': 'find_n_root_domains_ignoring_wildcards',
                        'found': list(found),
                        root_domains_label: list(found_root_domains),
                        'levels': levels,
                        'solution': solution}
Exemple #24
0
def to_smt(r):
    # type: (Rtl) -> Tuple[List[ExprRef], Z3VarMap]
    """
    Encode a concrete primitive Rtl r as a z3 query.
    Returns a tuple (query, var_m) where:
        - query is a list of z3 expressions
        - var_m is a map from Vars v with non-BVType to their corresponding z3
          bitvector variable.

    Every instruction in r must be one of the primitive instructions. The
    conversion instructions prim_to_bv/prim_from_bv only add an entry to
    var_m; every other instruction contributes one assertion to the query.

    bvult is encoded with z3's ULT. The `<` operator on z3 bit-vectors is
    the signed comparison and would encode the wrong relation.
    """
    # Imported here so the unsigned comparison is available regardless of
    # which z3 names the module imports at the top.
    from z3 import ULT

    assert r.is_concrete()
    # Should contain only primitives
    primitives = set(PRIMITIVES.instructions)
    assert set(d.expr.inst for d in r.rtl).issubset(primitives)

    q = []  # type: List[ExprRef]
    m = {}  # type: Z3VarMap

    # Build declarations for any bitvector Vars
    var_to_bv = {}  # type: Z3VarMap
    for v in r.vars():
        typ = v.get_typevar().singleton_type()
        if not isinstance(typ, BVType):
            continue

        var_to_bv[v] = BitVec(v.name, typ.bits)

    # Encode each instruction as a equality assertion
    for d in r.rtl:
        inst = d.expr.inst

        exp = None  # type: ExprRef
        # For prim_to_bv/prim_from_bv just update var_m. No assertion needed
        if inst == prim_to_bv:
            assert isinstance(d.expr.args[0], Var)
            m[d.expr.args[0]] = var_to_bv[d.defs[0]]
            continue

        if inst == prim_from_bv:
            assert isinstance(d.expr.args[0], Var)
            m[d.defs[0]] = var_to_bv[d.expr.args[0]]
            continue

        if inst in [bvadd, bvult]:  # Binary instructions
            assert len(d.expr.args) == 2 and len(d.defs) == 1
            lhs = d.expr.args[0]
            rhs = d.expr.args[1]
            df = d.defs[0]
            assert isinstance(lhs, Var) and isinstance(rhs, Var)

            if inst == bvadd:  # Normal binary - output type same as args
                exp = (var_to_bv[lhs] + var_to_bv[rhs])
            else:
                assert inst == bvult
                # Unsigned comparison: must be ULT, not the signed `<`
                exp = ULT(var_to_bv[lhs], var_to_bv[rhs])
                # Comparison binary - need to convert bool to BitVec 1
                exp = If(exp, BitVecVal(1, 1), BitVecVal(0, 1))

            exp = mk_eq(var_to_bv[df], exp)
        elif inst == bvzeroext:
            arg = d.expr.args[0]
            df = d.defs[0]
            assert isinstance(arg, Var)
            fromW = arg.get_typevar().singleton_type().width()
            toW = df.get_typevar().singleton_type().width()

            exp = mk_eq(var_to_bv[df], ZeroExt(toW - fromW, var_to_bv[arg]))
        elif inst == bvsignext:
            arg = d.expr.args[0]
            df = d.defs[0]
            assert isinstance(arg, Var)
            fromW = arg.get_typevar().singleton_type().width()
            toW = df.get_typevar().singleton_type().width()

            exp = mk_eq(var_to_bv[df], SignExt(toW - fromW, var_to_bv[arg]))
        elif inst == bvsplit:
            arg = d.expr.args[0]
            assert isinstance(arg, Var)
            arg_typ = arg.get_typevar().singleton_type()
            width = arg_typ.width()
            assert (width % 2 == 0)

            lo = d.defs[0]
            hi = d.defs[1]

            # lo receives the lower half of the bits, hi the upper half
            exp = And(
                mk_eq(var_to_bv[lo], Extract(width // 2 - 1, 0,
                                             var_to_bv[arg])),
                mk_eq(var_to_bv[hi],
                      Extract(width - 1, width // 2, var_to_bv[arg])))
        elif inst == bvconcat:
            assert isinstance(d.expr.args[0], Var) and \
                isinstance(d.expr.args[1], Var)
            lo = d.expr.args[0]
            hi = d.expr.args[1]
            df = d.defs[0]

            # Z3 Concat expects hi bits first, then lo bits
            exp = mk_eq(var_to_bv[df], Concat(var_to_bv[hi], var_to_bv[lo]))
        else:
            assert False, "Unknown primitive instruction {}".format(inst)

        q.append(exp)

    return (q, m)
Exemple #25
0
 def array_to_bv32(array):
     """Concatenate array elements 3 down to 0 into one bit-vector.

     Element 3 is passed to Concat first and element 0 last; indices are
     32-bit bit-vector constants.
     """
     elements = [
         Select(array, BitVecVal(index, 32)) for index in range(3, -1, -1)
     ]
     return Concat(*elements)
Exemple #26
0
def BVUnsignedUpCast(x, n_bits):
	"""Widen ``x`` to ``n_bits`` bits by prepending zero bits.

	A value that already has ``n_bits`` bits is returned unchanged.
	"""
	assert x.size() <= n_bits
	if not x.size() < n_bits:
		return x
	zero_pad = BitVecVal(0, n_bits - x.size())
	return Concat(zero_pad, x)
n_bits = 256

# Check that SIGNEXTEND with the matching byte index agrees with
# BVSignedCleanupFunction for every type width from 8 to 248 bits.
for type_bits in range(8, 256, 8):

    rule = Rule()

    # Input vars
    X = BitVec('X', n_bits)
    # SIGNEXTEND takes the index of the highest byte of the type.  Integer
    # division keeps the value an int; `/` would yield a float on Python 3.
    arg = BitVecVal(type_bits // 8 - 1, n_bits)

    cleaned_reference = BVSignedCleanupFunction(X, type_bits)
    cleaned = SIGNEXTEND(arg, X)

    rule.check(cleaned, cleaned_reference)

# Check that BVSignedCleanupFunction removes arbitrary upper bits: cleaning
# a value whose upper bits are unconstrained must give the same result as
# sign-extending the short value.
for type_bits in range(8, 256, 8):

    # Input vars: the short value and the unconstrained upper bits.
    X_short = BitVec('X', type_bits)
    dirt = BitVec('dirt', n_bits - type_bits)

    rule = Rule()

    # Reference: the short value sign-extended to the full width.
    X = BVSignedUpCast(X_short, n_bits)

    # Candidate: the short value with dirty upper bits, then cleaned.
    X_dirty = Concat(dirt, X_short)
    X_cleaned = BVSignedCleanupFunction(X_dirty, type_bits)

    rule.check(X, X_cleaned)