Example #1
0
 def verify(self, gadget):
     """
     Check whether `gadget` ends in a way accepted by this constraint.

     Returns a pair (ok, conds). `ok` is True when the gadget's ending is
     accepted; `conds` is empty for an unconditional match, or holds the
     `cond` of the IP semantic entry that makes a ret / jmp possible.
     """
     # No ending kind selected at all: every gadget is accepted
     if self.ret == self.jmp == self.call == False:
         return (True, [])
     # The gadget's classified ending is one of the selected kinds
     if ((self.ret and gadget.retType == RetType.RET)
             or (self.jmp and gadget.retType == RetType.JMP)
             or (self.call and gadget.retType == RetType.CALL)):
         return (True, [])

     # Otherwise inspect the possible final values of IP.
     if self.ret:
         # A ret is possible when the final IP is the memory value at
         # final_SP - register_size (the last stack slot consumed).
         for sem in gadget.getSemantics(Arch.ipNum()):
             if not isinstance(sem.expr, MEMExpr):
                 continue
             isInc, inc = sem.expr.addr.isRegIncrement(Arch.spNum())
             if isInc and inc == gadget.spInc - Arch.octets():
                 return (True, [sem.cond])
     if self.jmp:
         # A jump is possible when the final IP is an SSA register expression
         for sem in gadget.getSemantics(Arch.ipNum()):
             if isinstance(sem.expr, SSAExpr):
                 return (True, [sem.cond])
     return (False, [])
Example #2
0
def call(funcName, parsedArgs, constraint, assertion):
    """
    Build and print a ROPChain performing the syscall `funcName`.

    The syscall table is selected from the current binary type (x86 or x64
    ELF). Unsupported binary types, unknown syscalls and wrong argument
    counts are reported through error() and nothing else is printed. On
    success the chain is printed in the format selected by OUTPUT.
    """
    binType = Arch.currentBinType
    if binType == Arch.BinaryType.X86_ELF:
        table, system = Linux32, sysLinux32
    elif binType == Arch.BinaryType.X64_ELF:
        table, system = Linux64, sysLinux64
    else:
        error("Binary type '{}' not supported yet".format(binType))
        return
    syscall = table.available.get(funcName)

    if not syscall:
        error("Syscall '{}' not supported for system '{}'".format(
            funcName, system))
        return
    if len(parsedArgs) != len(syscall.args):
        error("Error. Wrong number of arguments")
        return

    # Build the syscall chain
    chain = _build_syscall(syscall.buildFunc, parsedArgs, constraint,
                           assertion)
    if not chain:
        print(string_bold("\n\tNo matching ROPChain found"))
        return

    print(string_bold("\n\tFound matching ROPChain\n"))
    badBytes = constraint.getBadBytes()
    if OUTPUT == OUTPUT_CONSOLE:
        print(chain.strConsole(Arch.bits(), badBytes))
    elif OUTPUT == OUTPUT_PYTHON:
        print(chain.strPython(Arch.bits(), badBytes))
Example #3
0
def CSTtoMEM_write(arg1, cst, constraint, assertion, n=1, clmax=LMAX):
    """
    Build chains performing mem(arg1) <- cst through an intermediate register:
        reg <- cst
        mem(arg1) <- reg

    arg1 is indexed as (addr_reg, addr_cst), the written address being
    addr_reg + addr_cst. Returns a list of at most `n` ROPChains whose length
    is <= clmax (empty list when nothing is found or clmax <= 0).
    """
    if clmax <= 0:
        return []

    res = []
    addr_reg = arg1[0]
    addr_cst = arg1[1]
    # The available writes to mem(arg1) and the padding value do not depend
    # on the intermediate register: compute them once, the first time they
    # are needed, instead of on every loop iteration.
    possible_mem_writes = None
    padding = None

    # 1. First strategy (direct)
    # reg <- cst
    # mem(arg1) <- reg
    for reg in range(0, Arch.ssaRegCount):
        if reg == Arch.ipNum() or reg == Arch.spNum() or reg == addr_reg:
            continue
        # Find reg <- cst without clobbering the address register
        # maxdepth 3 or it's too slow
        cst_to_reg_chains = search(QueryType.CSTtoREG,
                                   reg,
                                   cst,
                                   constraint.add(RegsNotModified([addr_reg])),
                                   assertion,
                                   n,
                                   clmax - 1,
                                   maxdepth=3)
        if not cst_to_reg_chains:
            continue

        # Search for mem(arg1) <- reg
        # We get all reg2,cst2 s.t mem(arg1) <- reg2+cst2
        if possible_mem_writes is None:
            possible_mem_writes = DBPossibleMemWrites(addr_reg,
                                                      addr_cst,
                                                      constraint,
                                                      assertion,
                                                      n=1)
            padding = constraint.getValidPadding(Arch.octets())

        # 1.A. Ideally we look for reg2=reg and cst2=0 (direct_writes)
        direct_writes = (possible_mem_writes.get(reg) or {}).get(0, [])
        for write_gadget in direct_writes:
            for cst_to_reg_chain in cst_to_reg_chains:
                # Pad the write gadget up to its stack consumption
                write_chain = ROPChain([write_gadget])
                for _ in range(0, write_gadget.spInc - Arch.octets(),
                               Arch.octets()):
                    write_chain.addPadding(padding)
                full_chain = cst_to_reg_chain.addChain(write_chain, new=True)
                if len(full_chain) <= clmax:
                    res.append(full_chain)
                if len(res) >= n:
                    return res
        # TODO: 1.B. writes through mem(arg1) <- reg2+cst2 with cst2 != 0
        # are not implemented yet.
    return res
Example #4
0
def _CSTtoREG_transitivity(reg, cst, env, n=1):
    """
    Perform REG1 <- CST with REG1 <- REG2 <- CST

    reg: destination register; cst: constant to load; env: the search
    environment (lmax budget, constraint, call history, unusable registers).
    Returns a list of at most `n` ROPChains (empty list on failure).
    The environment is restored before returning, even if a nested search
    raises.
    """
    ID = StrategyType.CSTtoREG_TRANSITIVITY

    ## Test for special cases
    # Test lmax
    if env.getLmax() <= 0:
        return []
    # Limit number of calls to this strategy
    if env.nbCalls(ID) >= 99:
        return []
    # Check if the cst is in badBytes
    if not env.getConstraint().badBytes.verifyAddress(cst):
        return []
    # Check if previous call was already CSTtoREG_transitivity
    # Reason: we handle the transitivity with REGtoREG transitivity
    # so no need to do it also recursively with this one ;)
    # (an empty history simply means there was no previous call)
    history = env.callsHistory()
    if history and history[-1] == ID:
        return []

    # Set env
    env.addCall(ID)
    res = []
    try:
        for inter in Arch.registers():
            if (inter == reg
                    or inter in env.getConstraint().getRegsNotModified()
                    or inter == Arch.ipNum() or inter == Arch.spNum()):
                continue
            # Find reg <- inter
            inter_to_reg = _search(QueryType.REGtoREG, reg, (inter, 0), env,
                                   n)
            if inter_to_reg:
                # We found ROPChains s.t reg <- inter
                # Now we want inter <- cst, within the remaining budget
                min_len = min(len(chain) for chain in inter_to_reg)
                env.subLmax(min_len)
                env.addUnusableReg(reg)
                try:
                    # Integer division: this is a number of chains (true
                    # division would give a float in Python 3)
                    cst_to_inter = _search(QueryType.CSTtoREG, inter, cst,
                                           env, n // len(inter_to_reg) + 1)
                finally:
                    env.removeUnusableReg(reg)
                    env.addLmax(min_len)

                for chain2 in inter_to_reg:
                    for chain1 in cst_to_inter:
                        if len(chain1) + len(chain2) <= env.getLmax():
                            res.append(chain1.addChain(chain2, new=True))

            # Did we get enough chains ?
            if len(res) >= n:
                break
    finally:
        # Restore env
        env.removeCall(ID)

    return res[:n]
Example #5
0
def _MEMtoREG_transitivity(reg, arg2, env, n=1):
    """
    Perform reg <- inter <- mem(arg2)

    Returns a list of at most `n` ROPChains, or a FailRecord merged from the
    failed sub-searches when no chain is found. The environment is restored
    before returning, even if a nested search raises.
    """
    ID = StrategyType.MEMtoREG_TRANSITIVITY

    ## Test for special cases
    # Test lmax
    if env.getLmax() <= 0:
        return FailRecord(lmax=True)
    # Limit number of calls to this strategy
    if env.nbCalls(ID) >= 99:
        return FailRecord()
    # Check if previous call was already MEMtoREG_transitivity
    # Reason: we handle the transitivity with REGtoREG transitivity
    # so no need to do it also recursively with this one ;)
    # (an empty history simply means there was no previous call)
    history = env.callsHistory()
    if history and history[-1] == ID:
        return FailRecord()

    # Set env
    env.addCall(ID)
    res = []
    res_fail = FailRecord()
    try:
        for inter in Arch.registers():
            if (inter == reg
                    or inter in env.getConstraint().getRegsNotModified()
                    or inter == Arch.ipNum() or inter == Arch.spNum()):
                continue

            # Find reg <- inter
            inter_to_reg = _search(QueryType.REGtoREG, reg, (inter, 0), env,
                                   n)
            if inter_to_reg:
                min_len = min(len(chain) for chain in inter_to_reg)
                # Try to find inter <- mem(arg2) within the remaining budget
                env.subLmax(min_len)
                env.addUnusableReg(reg)
                try:
                    arg2_to_inter = _search(QueryType.MEMtoREG, inter, arg2,
                                            env, n)
                finally:
                    env.removeUnusableReg(reg)
                    env.addLmax(min_len)
                if not arg2_to_inter:
                    res_fail.merge(arg2_to_inter)
                    continue
                res += [chain1.addChain(chain2, new=True)
                        for chain1 in arg2_to_inter
                        for chain2 in inter_to_reg
                        if len(chain1) + len(chain2) <= env.getLmax()]
            else:
                res_fail.merge(inter_to_reg)
            # Did we get enough chains ?
            if len(res) >= n:
                break
    finally:
        # Restore env
        env.removeCall(ID)

    return res[:n] if res else res_fail
Example #6
0
def _CSTtoREG_transitivity(reg,
                           cst,
                           constraint,
                           assertion,
                           n=1,
                           clmax=LMAX,
                           comment=None,
                           maxdepth=4):
    """
    Perform REG1 <- CST with REG1 <- REG2 <- CST

    For each candidate intermediate register REG2, look for chains
    REG1 <- REG2, then REG2 <- CST, first with _basic and, if still short
    of `n` results, by popping CST from the stack (_CSTtoREG_pop).
    Returns a list of at most `n` ROPChains whose length is <= clmax.
    """
    # Test clmax
    if clmax <= 0:
        return []

    res = []
    for inter in range(0, Arch.ssaRegCount):
        if (inter == reg or inter in constraint.getRegsNotModified()
                or inter == Arch.ipNum() or inter == Arch.spNum()):
            continue
        # Find reg <- inter, with reg marked unusable for REGtoREG sub-searches
        REGtoREG_record = SearchRecord(maxdepth=maxdepth)
        REGtoREG_record.unusable_REGtoREG.append(reg)
        inter_to_reg = search(QueryType.REGtoREG,
                              reg, (inter, 0),
                              constraint,
                              assertion,
                              n,
                              clmax,
                              record=REGtoREG_record)
        if inter_to_reg:
            # We found ROPChains s.t reg <- inter
            # Now we want inter <- cst. Integer division: this is a number
            # of chains (true division would give a float in Python 3).
            quota = n // len(inter_to_reg) + 1
            cst_to_inter = _basic(QueryType.CSTtoREG, inter, cst, constraint,
                                  assertion, quota, clmax - 1)
            for chain2 in inter_to_reg:
                for chain1 in cst_to_inter:
                    if len(chain1) + len(chain2) <= clmax:
                        res.append(chain1.addChain(chain2, new=True))
            if len(res) < n:
                cst_to_inter = _CSTtoREG_pop(inter, cst, constraint, assertion,
                                             quota, clmax - 1, comment)
                for chain2 in inter_to_reg:
                    for chain1 in cst_to_inter:
                        if len(chain1) + len(chain2) <= clmax:
                            res.append(chain1.addChain(chain2, new=True))

        # Did we get enough chains ?
        if len(res) >= n:
            return res[:n]
    # Return what we got
    return res
Example #7
0
def _REGtoMEM_transitivity(arg1, arg2, env, n=1):
    """
    reg <- arg2
    mem(arg1) <- reg

    Returns a list of at most `n` ROPChains, or a FailRecord merged from the
    failed sub-searches when no chain is found. The environment is restored
    before returning, even if a nested search raises.
    """
    ID = StrategyType.REGtoMEM_TRANSITIVITY

    ## Test for special cases
    # Test lmax
    if env.getLmax() <= 0:
        return FailRecord(lmax=True)
    # Limit number of calls to this strategy
    if env.nbCalls(ID) >= 99:
        return FailRecord()
    # Check if previous call was already REGtoMEM_transitivity
    # Reason: we handle the transitivity with REGtoREG transitivity
    # so no need to do it also recursively with this one ;)
    # (an empty history simply means there was no previous call)
    history = env.callsHistory()
    if history and history[-1] == ID:
        return FailRecord()

    # Set env
    env.addCall(ID)
    res = []
    res_fail = FailRecord()
    try:
        for inter in Arch.registers():
            if (inter == arg2[0]
                    or inter in env.getConstraint().getRegsNotModified()
                    or inter == Arch.ipNum() or inter == Arch.spNum()):
                continue
            # Find inter <- arg2
            arg2_to_inter = _search(QueryType.REGtoREG, inter,
                                    (arg2[0], arg2[1]), env, n)
            if not arg2_to_inter:
                res_fail.merge(arg2_to_inter)
                continue
            len_min = min(len(chain) for chain in arg2_to_inter)
            # Try to find mem(arg1) <- inter within the remaining budget
            env.subLmax(len_min)
            env.addUnusableReg(arg2[0])
            try:
                inter_to_mem = _search(QueryType.REGtoMEM, arg1, (inter, 0),
                                       env)
            finally:
                env.removeUnusableReg(arg2[0])
                env.addLmax(len_min)
            if not inter_to_mem:
                res_fail.merge(inter_to_mem)
                continue
            res += [chain1.addChain(chain2, new=True)
                    for chain1 in arg2_to_inter
                    for chain2 in inter_to_mem
                    if len(chain1) + len(chain2) <= env.getLmax()]
            if len(res) >= n:
                break
    finally:
        # Restore env
        env.removeCall(ID)
    # Never return more than the n chains requested (like the other
    # transitivity strategies)
    return res[:n] if res else res_fail
Example #8
0
def build_mprotect64(addr, size, prot=7, constraint=None, assertion=None, clmax=SYSCALL_LMAX, optimizeLen=False):
    """
    Call mprotect from X86-64 arch
    Args must be on registers (rdi, rsi, rdx):
    Sizes are (unsigned long, size_t, unsigned long)
    rax must be 10

    First tries to call an mprotect() function found in the binary. If that
    fails and the chain does not need to return (constraint.chainable.ret is
    false), falls back to setting the registers and using a syscall gadget.
    Returns the ROPChain, or None on failure.
    """
    # Check args: report the first non-integer one
    for value in (addr, size, prot):
        if not isinstance(value, int):
            error("Argument error. Expected integer, got " + str(type(value)))
            return None

    constraint = Constraint() if constraint is None else constraint
    assertion = Assertion() if assertion is None else assertion

    # Check if we have the function !
    verbose("Trying to call mprotect() function directly")
    func_call = build_call('mprotect', [addr, size, prot], constraint,
                           assertion, clmax=clmax, optimizeLen=optimizeLen)
    if not isinstance(func_call, str):
        verbose("Success")
        return func_call
    if constraint.chainable.ret:
        verbose("Couldn't call mprotect() and return to ROPChain")
        return None
    verbose("Coudn't call mprotect(), try direct syscall")

    # Otherwise do the syscall by 'hand': set the registers...
    regs_values = [[Arch.n2r('rdi'), addr], [Arch.n2r('rsi'), size],
                   [Arch.n2r('rdx'), prot], [Arch.n2r('rax'), 10]]
    chain = popMultiple(regs_values, constraint, assertion, clmax - 1,
                        optimizeLen)
    if not chain:
        verbose("Failed to set registers for the mprotect syscall")
        return None
    # ...then append a syscall gadget
    syscalls = search(QueryType.SYSCALL, None, None, constraint, assertion)
    if not syscalls:
        verbose("Failed to find a syscall gadget")
        return None
    chain.addChain(syscalls[0])
    verbose("Success")
    return chain
Example #9
0
def build_call(funcName, funcArgs, constraint, assertion):
    """
    Build a ROPChain calling `funcName` with integer arguments on the stack.

    Layout: function address, then the address of a gadget that skips the
    arguments (used as fake return address), then the arguments.
    Returns the ROPChain on success, or an error message (str) on failure.
    """
    # Find the address of the function
    (funcName2, funcAddr) = getFunctionAddress(funcName)
    if funcName2 is None:
        return "Couldn't find function '{}' in the binary".format(funcName)

    # Check if bad bytes in function address
    if not constraint.badBytes.verifyAddress(funcAddr):
        return "'{}' address ({}) contains bad bytes".format(
            funcName2,
            string_special('0x' + format(funcAddr, '0' +
                                         str(Arch.octets() * 2) + 'x')))

    # Find a gadget for the fake return address: it must load IP from the
    # stack just past the arguments. Each argument is added as one padding
    # word, so step by Arch.octets() (a hard-coded 8 is only right on 64-bit).
    word = Arch.octets()
    offset = len(funcArgs) * word - word  # Because we do +word at the beginning of the loop
    skip_args_chains = []
    i = 4  # Try 4 offsets maximum
    while i > 0 and not skip_args_chains:
        offset += word
        skip_args_chains = search(QueryType.MEMtoREG, Arch.ipNum(),
                                  (Arch.spNum(), offset), constraint,
                                  assertion, n=1)
        i -= 1

    if not skip_args_chains:
        return "Couldn't build ROP-Chain"
    skip_args_chain = skip_args_chains[0]

    # Build the ropchain with the arguments (pushed in reverse order)
    args_chain = ROPChain()
    arg_n = len(funcArgs)
    for arg in reversed(funcArgs):
        if isinstance(arg, int):
            args_chain.addPadding(arg,
                                  comment="Arg{}: {}".format(
                                      arg_n, string_ropg(hex(arg))))
            arg_n -= 1
        else:
            return "Type of argument '{}' not supported yet :'(".format(arg)

    # Build call chain (function address + fake return address)
    call_chain = ROPChain()
    call_chain.addPadding(funcAddr, comment=string_ropg(funcName2))
    skip_args_addr = int(
        validAddrStr(skip_args_chain.chain[0], constraint.getBadBytes(),
                     Arch.bits()), 16)
    call_chain.addPadding(skip_args_addr,
                          comment="Address of: " +
                          string_bold(str(skip_args_chain.chain[0])))

    return call_chain.addChain(args_chain)
Example #10
0
def build_call_linux86(funcName, funcArgs, constraint, assertion, clmax=None, optimizeLen=False):
    """
    Build a ROPChain calling `funcName` with the x86 (stack) calling convention.

    Layout: function address, then (if there are arguments) the address of a
    gadget skipping them, used as fake return address, then the arguments.
    clmax=None means no length limit. Returns the ROPChain on success, or an
    error message (str) on failure.
    """
    # Find the address of the function
    (funcName2, funcAddr) = getFunctionAddress(funcName)
    if funcName2 is None:
        return "Couldn't find function '{}' in the binary".format(funcName)

    # Check if bad bytes in function address
    if not constraint.badBytes.verifyAddress(funcAddr):
        return "'{}' address ({}) contains bad bytes".format(funcName2, string_special('0x'+format(funcAddr, '0'+str(Arch.octets()*2)+'x')))

    # Check if lmax too small: function address + one slot per argument +
    # the skip-args gadget address when there are arguments.
    # clmax may be None (no limit): comparing an int to None raises
    # TypeError in Python 3, so only check when a limit is given.
    needed = 1 + len(funcArgs) + (1 if funcArgs else 0)
    if clmax is not None and needed > clmax:
        return "Not enough bytes to call function '{}'".format(funcName)

    # Find a gadget for the fake return address
    if funcArgs:
        offset = (len(funcArgs)-1)*Arch.octets()  # Because we do +octets() at the beginning of the loop
        skip_args_chains = []
        i = 4  # Try 4 more maximum
        while i > 0 and not skip_args_chains:
            offset += Arch.octets()
            skip_args_chains = search(QueryType.MEMtoREG, Arch.ipNum(),
                                      (Arch.spNum(), offset), constraint,
                                      assertion, n=1, optimizeLen=optimizeLen)
            i -= 1

        if not skip_args_chains:
            return "Couldn't build ROP-Chain"
        skip_args_chain = skip_args_chains[0]
    else:
        # No arguments
        skip_args_chain = None

    # Build the ropchain with the arguments (pushed in reverse order)
    args_chain = ROPChain()
    arg_n = len(funcArgs)
    for arg in reversed(funcArgs):
        if isinstance(arg, int):
            args_chain.addPadding(arg, comment="Arg{}: {}".format(arg_n, string_ropg(hex(arg))))
            arg_n -= 1
        else:
            return "Type of argument '{}' not supported yet :'(".format(arg)

    # Build call chain (function address + fake return address)
    call_chain = ROPChain()
    call_chain.addPadding(funcAddr, comment=string_ropg(funcName2))
    if funcArgs:
        skip_args_addr = int(validAddrStr(skip_args_chain.chain[0], constraint.getBadBytes(), Arch.bits()), 16)
        call_chain.addPadding(skip_args_addr, comment="Address of: "+string_bold(str(skip_args_chain.chain[0])))

    return call_chain.addChain(args_chain)
Example #11
0
def init_impossible_REGtoREG(env):
    """
    Pre-compute which REGtoREG queries (reg1 <- reg2+0) are impossible.

    Runs a REGtoREG search for every register pair (except reg2 == reg1 or
    reg2 == IP) in `env`, counts the pairs reported impossible by
    env.checkImpossible_REGtoREG, and prints the resulting rate and time.
    If the process is interrupted or fails, the module-level
    global_impossible_REGtoREG environment is replaced by a fresh one.
    """
    global INIT_LMAX, INIT_MAXDEPTH
    global baseAssertion
    global global_impossible_REGtoREG

    try:
        startTime = datetime.now()
        registers = Arch.registers()
        nb_regs = len(registers)
        total = nb_regs * nb_regs
        i = 0
        impossible_count = 0
        for reg1 in sorted(registers):
            for reg2 in registers:
                i += 1
                charging_bar(total, i, 30)
                if reg2 == reg1 or reg2 == Arch.ipNum():
                    continue
                _search(QueryType.REGtoREG, reg1, (reg2, 0), env, n=1)
                if env.checkImpossible_REGtoREG(reg1, reg2, 0):
                    impossible_count += 1
        cTime = datetime.now() - startTime
        # Get how many impossible path we found (guard against 0 or 1 register)
        nb_pairs = (nb_regs - 1) * nb_regs
        if nb_pairs:
            impossible_rate = int(100 * (float(impossible_count) / float(nb_pairs)))
        else:
            impossible_rate = 0
        notify('Optimization rate : {}%'.format(impossible_rate))
        notify("Computation time : " + str(cTime))
    except (Exception, KeyboardInterrupt):
        # KeyboardInterrupt is listed explicitly so the user can still stop
        # this (possibly long) init with Ctrl-C, as the former bare except allowed.
        print("\n")
        fatal("Exception caught, stopping Semantic Engine init process...\n")
        fatal("Search time might get very long !\n")
        # Rebinding the local `env` had no effect outside this function:
        # reset the module-level environment that the engine uses instead.
        global_impossible_REGtoREG = SearchEnvironment(INIT_LMAX, Constraint(),
                                                       baseAssertion,
                                                       INIT_MAXDEPTH)
Example #12
0
def init_impossible_REGtoREG(env):
    """
    Pre-compute which REGtoREG queries (reg1 <- reg2+0) are impossible.

    Runs a REGtoREG search for every register pair (except reg2 == reg1 or
    reg2 == IP) in `env`, counts the pairs reported impossible by
    env.checkImpossible_REGtoREG, and prints the resulting rate and time.
    Exceptions are not caught.
    """
    startTime = datetime.now()
    registers = Arch.registers()
    nb_regs = len(registers)
    total = nb_regs * nb_regs
    i = 0
    impossible_count = 0
    for reg1 in sorted(registers):
        for reg2 in registers:
            i += 1
            charging_bar(total, i, 30)
            if reg2 == reg1 or reg2 == Arch.ipNum():
                continue
            _search(QueryType.REGtoREG, reg1, (reg2, 0), env, n=1)
            if env.checkImpossible_REGtoREG(reg1, reg2, 0):
                impossible_count += 1
    cTime = datetime.now() - startTime
    # Get how many impossible path we found (guard against 0 or 1 register)
    nb_pairs = (nb_regs - 1) * nb_regs
    if nb_pairs:
        impossible_rate = int(100 * (float(impossible_count) / float(nb_pairs)))
    else:
        impossible_rate = 0
    notify('Optimization rate : {}%'.format(impossible_rate))
    notify("Computation time : " + str(cTime))
Example #13
0
def build_call_linux64(funcName, funcArgs, constraint, assertion, clmax=None, optimizeLen=False):
    """
    Build a ROPChain calling `funcName` with the Linux x86-64 calling convention.

    Up to 6 arguments are loaded into rdi, rsi, rdx, rcx, r8, r9 (in that
    order) with popMultiple; each argument is concatenated after its
    register as `(reg,) + arg`, so arguments are expected to be tuples.
    Returns the ROPChain on success, or an error message (str) on failure.
    """
    # Arguments registers
    # (Args should go in these registers for x64)
    argsRegsNames = ['rdi', 'rsi', 'rdx', 'rcx', 'r8', 'r9']
    argsRegs = [Arch.n2r(name) for name in argsRegsNames]
    # Find the address of the function
    (funcName2, funcAddr) = getFunctionAddress(funcName)
    if funcName2 is None:
        return "Couldn't find function '{}' in the binary".format(funcName)

    # Check if bad bytes in function address
    if not constraint.badBytes.verifyAddress(funcAddr):
        return "'{}' address ({}) contains bad bytes".format(funcName2, string_special('0x'+format(funcAddr, '0'+str(Arch.octets()*2)+'x')))

    # Check how many arguments
    if len(funcArgs) > 6:
        return "Doesn't support function call with more than 6 arguments with Linux X64 calling convention :("

    # Load the arguments into their registers
    if funcArgs:
        # Build a real list: in Python 3 map() returns a one-shot iterator,
        # which popMultiple cannot measure or traverse more than once.
        regs_and_args = [(reg,) + arg for reg, arg in zip(argsRegs, funcArgs)]
        args_chain = popMultiple(regs_and_args, constraint, assertion, clmax=clmax, optimizeLen=optimizeLen)
        if not args_chain:
            return "Couldn't load arguments in registers"
    else:
        # No arguments
        args_chain = ROPChain()

    # Append the function address to jump to it
    return args_chain.addPadding(funcAddr, comment=string_ropg(funcName2))
Example #14
0
def getFunctionAddress(name):
    """
    Looks for the function 'name' in the PLT relocations (.rela.plt) of the
    binary, then in its symbol table sections.
    Returns a pair (name, address) as (str, int), or (None, None) when the
    binary is not an ELF or the function is not found.
    """
    global binary_name
    global binary_ELF

    if not Arch.currentIsELF():
        return (None, None)

    # Get function in relocations. The section may be absent (e.g. static
    # binaries): get_section_by_name then returns None, which must not be
    # dereferenced; fall back to the symbol tables instead.
    relasec_name = '.rela.plt'
    relasec = binary_ELF.get_section_by_name(relasec_name)
    if isinstance(relasec, RelocationSection):
        relasec_addr = relasec.header['sh_addr']
        symbols = binary_ELF.get_section(relasec.header['sh_link'])
        if not isinstance(symbols, NullSection):
            for reloc in relasec.iter_relocations():
                if symbols.get_symbol(reloc['r_info_sym']).name == name:
                    return (name, reloc['r_offset'] + relasec_addr)

    # Get function from symbol table sections
    for symsec in getSymbolSections():
        function = symsec.get_symbol_by_name(name)
        if function:
            return (name, function[0]['st_value'])
    return (None, None)
Example #15
0
def find(args):
    """
    args - List of user arguments as strings
    (the command should not be included in the list as args[0])

    Parses the query, searches for matching ROPChains and prints them. When
    no chain is found, prints the possibly matching non-chainable gadgets.
    """
    if not args or args[0] in (OPTION_HELP, OPTION_HELP_SHORT):
        print_help()
        return

    parsed_args = parse_args(args)
    if not parsed_args[0]:
        error(parsed_args[1])
        return

    qtype, arg1, arg2, constraint, nbResults, clmax = parsed_args[1:7]
    # The stack pointer area is assumed readable/writable
    assertion = Assertion().add(
        RegsValidPtrRead([(Arch.spNum(), -5000, 10000)])).add(
        RegsValidPtrWrite([(Arch.spNum(), -5000, 0)]))

    # Search
    res = search(qtype, arg1, arg2, constraint, assertion,
                 n=nbResults, clmax=clmax)
    if res:
        title = "Built matching ROPChain(s)"
    else:
        res = search_not_chainable(qtype, arg1, arg2, constraint, assertion,
                                   n=nbResults, clmax=clmax)
        title = "Possibly matching gadget(s)"
    print_chains(res, title, constraint.getBadBytes())
Example #16
0
def _REGtoREG_transitivity(arg1,
                           arg2,
                           constraint,
                           assertion,
                           record,
                           n=1,
                           clmax=LMAX):
    """
    Perform REG1 <- REG2+CST with REG1 <- REG3 <- REG2+CST

    arg1 is the destination register, arg2 is indexed as (REG2, CST).
    Returns a list of at most `n` ROPChains whose length is <= clmax.
    `record.unusable_REGtoREG` is temporarily extended during the nested
    searches and is always restored before returning (including on the
    early return once `n` chains are found, or if a search raises).
    """
    # Test clmax
    if clmax <= 0:
        return []

    # If reg1 <- reg1 + 0, return
    if arg1 == arg2[0] and arg2[1] == 0:
        return []

    res = []
    for inter_reg in range(0, Arch.ssaRegCount):
        if (inter_reg == arg1 or (inter_reg == arg2[0] and arg2[1] == 0)
                or inter_reg in record.unusable_REGtoREG
                or inter_reg == Arch.ipNum() or inter_reg == Arch.spNum()):
            continue
        # Find reg1 <- inter_reg without using arg2
        record.unusable_REGtoREG.append(arg2[0])
        try:
            inter_to_arg1_list = search(QueryType.REGtoREG, arg1,
                                        (inter_reg, 0), constraint, assertion,
                                        n, clmax=clmax-1, record=record)
        finally:
            record.unusable_REGtoREG.remove(arg2[0])
        if not inter_to_arg1_list:
            continue

        # Find inter_reg <- arg2 without using arg1. Integer division: this
        # is a number of chains, at least 1.
        n2 = max(n // len(inter_to_arg1_list), 1)
        record.unusable_REGtoREG.append(arg1)
        try:
            for arg2_to_inter in search(QueryType.REGtoREG, inter_reg, arg2,
                                        constraint, assertion, n2,
                                        clmax=clmax-1, record=record):
                for inter_to_arg1 in inter_to_arg1_list:
                    if len(inter_to_arg1) + len(arg2_to_inter) <= clmax:
                        res.append(arg2_to_inter.addChain(inter_to_arg1,
                                                          new=True))
                    if len(res) >= n:
                        return res
        finally:
            # Previously skipped on the early return above, leaking arg1
            # into the caller's record
            record.unusable_REGtoREG.remove(arg1)
    return res
Example #17
0
def initEngine():
    """
    Initialize the semantic engine.

    Sets the module-level baseAssertion (stack pointer area valid for reads
    and writes), then builds global_impossible_REGtoREG as a fresh
    SearchEnvironment and fills it with init_impossible_REGtoREG.
    """
    global INIT_LMAX, INIT_MAXDEPTH
    global global_impossible_REGtoREG
    global baseAssertion

    # Init global variables
    readCheck = RegsValidPtrRead([(Arch.spNum(), -5000, 10000)])
    writeCheck = RegsValidPtrWrite([(Arch.spNum(), -5000, 0)])
    baseAssertion = Assertion().add(readCheck, copy=False).add(writeCheck,
                                                              copy=False)

    info(string_bold("Initializing Semantic Engine\n"))

    # Init helper for REGtoREG
    global_impossible_REGtoREG = SearchEnvironment(INIT_LMAX, Constraint(),
                                                   baseAssertion,
                                                   INIT_MAXDEPTH)
    init_impossible_REGtoREG(global_impossible_REGtoREG)
Example #18
0
def build_mprotect64(addr,
                     size,
                     prot=7,
                     constraint=None,
                     assertion=None,
                     clmax=None,
                     optimizeLen=False):
    """
    Call mprotect from X86-64 arch
    Args must be on registers (rdi, rsi, rdx):
    Sizes are (unsigned long, size_t, unsigned long)
    rax must be 10

    clmax=None means no length limit. Returns the ROPChain, or None on
    failure (errors are reported through error()/verbose()).
    """
    # Check args
    if not isinstance(addr, int):
        error("Argument error. Expected integer, got " + str(type(addr)))
        return None
    elif not isinstance(size, int):
        error("Argument error. Expected integer, got " + str(type(size)))
        return None
    elif not isinstance(prot, int):
        error("Argument error. Expected integer, got " + str(type(prot)))
        return None

    if constraint is None:
        constraint = Constraint()
    if assertion is None:
        assertion = Assertion()

    # Set the registers. Keep one slot for the syscall gadget when a limit
    # is given; with the default clmax=None, `clmax - 1` raised TypeError,
    # so forward None (no limit) to popMultiple instead.
    regs_clmax = None if clmax is None else clmax - 1
    args = [[Arch.n2r('rdi'), addr], [Arch.n2r('rsi'), size],
            [Arch.n2r('rdx'), prot], [Arch.n2r('rax'), 10]]
    chain = popMultiple(args, constraint, assertion, regs_clmax, optimizeLen)
    if not chain:
        verbose("Failed to set registers for the mprotect syscall")
        return None
    # Syscall
    syscalls = search(QueryType.SYSCALL, None, None, constraint, assertion)
    if not syscalls:
        verbose("Failed to find a syscall gadget")
        return None
    chain.addChain(syscalls[0])
    verbose("Success")
    return chain
Example #19
0
 def __init__(self, num, ind=0, size=None):
     """
     Build an expression wrapping a single SSA register.

     num: a register number, or an SSAReg whose num/ind are copied (in
          which case `ind` is ignored)
     ind: SSA index used when `num` is a plain register number
     size: size in bits, Arch.bits() when not given
     """
     Expr.__init__(self)
     if isinstance(num, SSAReg):
         num, ind = num.num, num.ind
     self.reg = SSAReg(num, ind)
     self.size = Arch.bits() if size is None else size
Example #20
0
def initScanner(filename):
    """
    Record the binary to scan in the module globals and, when the current
    architecture describes an ELF binary, parse it with ElfParser
    (binary_ELF is None otherwise).
    """
    global binary_name
    global binary_ELF

    binary_name = filename
    binary_ELF = ElfParser(binary_name) if Arch.currentIsELF() else None
Example #21
0
def _CSTtoREG_pop(reg,
                  cst,
                  constraint,
                  assertion,
                  n=1,
                  clmax=LMAX,
                  comment=None):
    """
    Returns a payload that puts cst into register reg by poping it from the stack

    Uses gadgets that load `reg` from the stack at a non-negative offset:
    the gadget is followed by padding words, with `cst` written in the word
    at that offset. Returns a list of at most `n` ROPChains of length
    <= clmax (empty list when clmax <= 0, n < 1, or cst has bad bytes).
    """
    # Test clmax
    if clmax <= 0:
        return []
    # Test n
    if n < 1:
        return []
    # Check if the cst is incompatible with the constraint
    if not constraint.badBytes.verifyAddress(cst):
        return []

    if not comment:
        comment = "Constant: " + string_bold("0x{:x}".format(cst))

    # Direct pop from the stack
    res = []
    if reg == Arch.ipNum():
        constraint2 = constraint.remove([CstrTypeID.CHAINABLE])
    else:
        constraint2 = constraint.add(Chainable(ret=True))
    possible = DBPossiblePopOffsets(reg, constraint2, assertion)
    word = Arch.octets()
    # The padding value does not depend on the offset: compute it once
    padding = constraint.getValidPadding(word)
    for offset in sorted(x for x in possible.keys() if x >= 0):
        # If offsets are too big to fit in the lmax just break
        if offset > clmax * word:
            break
        # The constant is stored as one whole padding word: an offset that
        # is not word-aligned never matches a slot, and the chain would be
        # built without the constant at all.
        if offset % word != 0:
            continue
        # Get possible gadgets
        possible_gadgets = [g for g in possible[offset]
                            if g.spInc >= word
                            and g.spInc - word > offset
                            and (g.spInc/word - 1) <= clmax]  # Test if padding is too much for clmax
        # Pad the gadgets
        for gadget in possible_gadgets:
            chain = ROPChain([gadget])
            for i in range(0, gadget.spInc - word, word):
                if i == offset:
                    chain.addPadding(cst, comment)
                else:
                    chain.addPadding(padding)
            if len(chain) <= clmax:
                res.append(chain)
            if len(res) >= n:
                return res
    return res
Example #22
0
def initScanner(filename):
    """
    Record the binary to scan and, for ELF targets, parse it with ELFFile.

    Sets the module globals `binary_name` and `binary_ELF`. Opening the
    file still raises if it is missing or unreadable. For ELF targets the
    handle is kept open and handed to ELFFile. Otherwise, or if parsing
    fails, the handle is closed instead of being leaked.
    """
    global binary_name
    global binary_ELF

    binary_name = filename
    f = open(binary_name, 'rb')

    try:
        if Arch.currentIsELF():
            # ELFFile (pyelftools) reads sections from the stream on demand,
            # so the handle must stay open while binary_ELF is in use.
            binary_ELF = ELFFile(f)
            return
    except Exception:
        f.close()
        raise

    # Not an ELF target: the handle is not needed any more.
    f.close()
    binary_ELF = None
Example #23
0
 def addOffset(self, offset):
     """
     Shift every address in self.addrList by `offset`.

     Every shifted address must lie in [0, 2**Arch.bits()). If one does
     not, return False and leave self.addrList untouched. Otherwise replace
     self.addrList with the shifted addresses and return True.
     """
     shifted = []
     for address in self.addrList:
         candidate = address + offset
         fits = candidate < (0x1 << Arch.bits()) and candidate >= 0
         if not fits:
             return False
         shifted.append(candidate)
     self.addrList = shifted
     return True
Example #24
0
def find_best_valid_writes(addr, string, constraint, limit=None):
    """
    Split `string` into register-sized writes whose addresses avoid bad bytes.

    When using the write strategy, the destination addresses can contain
    bad bytes too. This tries successive start addresses beginning at
    `addr`, one byte apart. For each start it cuts `string` into chunks of
    at most Arch.octets() bytes, greedily taking the largest chunk whose
    following address passes `constraint.badBytes.verifyAddress`. The start
    address itself must pass the check too.

    Args:
        addr: first candidate start address (int).
        string: data to write; a str (one character per byte) or bytes.
        constraint: provides `badBytes.verifyAddress(int)` and
            `getValidPadding(int)`.
        limit: the written string must end at or before this address.
            Defaults to ``addr + len(string) + 10`` when falsy.

    Returns:
        A list of ``(write_address, register_value)`` pairs, or None if no
        start address in range works.
    """
    def byte_value(b):
        # Iterating bytes yields ints, iterating str yields 1-char strings.
        return b if isinstance(b, int) else ord(b)

    def string_into_reg(chunk):
        """
        Pack `chunk` into one register value in memory order.

        Unused bytes are filled with constraint.getValidPadding(). Returns
        None when no padding is available or the endianness is unknown.
        """
        bytes_list = [byte_value(b) for b in chunk]
        pad_len = Arch.octets() - len(bytes_list)
        if pad_len != 0:
            value = constraint.getValidPadding(pad_len)
            if value is None:
                return None
        else:
            value = 0

        if Arch.isLittleEndian():
            # First byte becomes the least significant; padding sits above.
            for byte in reversed(bytes_list):
                value = (value << 8) + byte
            return value
        elif Arch.isBigEndian():
            data = 0
            for byte in bytes_list:
                data = (data << 8) + byte
            # Data occupies the high-order bytes, padding the pad_len low ones.
            return (data << (8 * pad_len)) + value
        else:
            return None

    def split_from(start):
        """Return the list of writes for `string` placed at `start`, or None."""
        verify = constraint.badBytes.verifyAddress
        # The first write goes to `start`, so it must be a valid address too.
        if string and not verify(start):
            return None
        writes = []
        offset = 0
        while offset < len(string):
            # Largest chunk whose following (next write) address is valid.
            for size in range(Arch.octets(), 0, -1):
                if verify(start + offset + size):
                    break
            else:
                return None
            value = string_into_reg(string[offset:offset + size])
            if value is None:
                return None
            writes.append((start + offset, value))
            offset += size
        return writes

    if not limit:
        limit = addr + len(string) + 10
    start = addr
    while start + len(string) <= limit:
        writes = split_from(start)
        if writes is not None:
            return writes
        # Shift the whole write by one byte and retry.
        start += 1
    return None
Example #25
0
 def verify(self, gadget):
     """
     Check whether the gadget has at least one address free of bad bytes.

     Each address is formatted as zero-padded lowercase hex (two digits per
     byte of Arch.octets()) and split into two-character chunks. The
     address is accepted when none of self.bytes appears among those
     chunks. self.bytes presumably holds 2-digit hex strings; confirm this
     against the constructor.

     Returns (True, []) at the first clean address, else (False, []).
     """
     for address in gadget.addrList:
         width = Arch.octets() * 2
         hexBytes = re.findall('..', format(address, '0' + str(width) + 'x'))
         if not any(bad in hexBytes for bad in self.bytes):
             # No bad bytes found, so this address is usable
             return (True, [])
     return (False, [])
Example #26
0
    def string_into_reg(string):
        """
        Pack `string` (at most Arch.octets() bytes) into one register value.

        The value matches the in-memory byte layout for the current
        endianness. Unused bytes are filled with
        `constraint.getValidPadding(...)`; `constraint` comes from the
        enclosing scope. Accepts str (one character per byte) or bytes.
        Returns None when no valid padding exists or the endianness is
        unknown.
        """
        # Iterating bytes yields ints, iterating str yields 1-char strings.
        bytes_list = [b if isinstance(b, int) else ord(b) for b in string]
        pad_len = Arch.octets() - len(bytes_list)
        # Get base value
        if pad_len != 0:
            value = constraint.getValidPadding(pad_len)
            if value is None:
                return None
        else:
            value = 0

        if Arch.isLittleEndian():
            # First byte becomes the least significant; padding sits above.
            for byte in reversed(bytes_list):
                value = (value << 8) + byte
            return value
        elif Arch.isBigEndian():
            data = 0
            for byte in bytes_list:
                data = (data << 8) + byte
            # Data occupies the high-order bytes, so shift it past the
            # pad_len padding bytes (not past its own length).
            return (data << (8 * pad_len)) + value
        else:
            return None
Example #27
0
def build_mprotect32(addr, size, prot=7, constraint=None, assertion=None, clmax=None, optimizeLen=False):
    """
    Build a chain calling mprotect() through the X86 'int 0x80' syscall.

    int mprotect(void *addr, size_t len, int prot)
    Arguments are passed in registers: ebx = addr, ecx = len, edx = prot,
    and eax holds the syscall number 0x7d (125).

    Args:
        addr, size, prot: integer syscall arguments (prot defaults to 7).
        constraint, assertion: search constraints. Fresh Constraint() and
            Assertion() objects are used when None.
        clmax: maximum chain length, with one slot reserved for the
            'int 0x80' gadget. None is forwarded to popMultiple unchanged
            (presumably it then applies its own default; TODO confirm).
        optimizeLen: forwarded to popMultiple.

    Returns:
        The chain, or None on invalid arguments or when no chain is found.
    """
    # Check args
    if not isinstance(addr, int):
        error("Argument error. Expected integer, got " + str(type(addr)))
        return None
    elif not isinstance(size, int):
        error("Argument error. Expected integer, got " + str(type(size)))
        return None
    elif not isinstance(prot, int):
        error("Argument error. Expected integer, got " + str(type(prot)))
        return None

    if constraint is None:
        constraint = Constraint()
    if assertion is None:
        assertion = Assertion()

    # Set the registers: syscall number in eax, arguments in ebx/ecx/edx
    args = [[Arch.n2r('eax'), 0x7d], [Arch.n2r('ebx'), addr],
            [Arch.n2r('ecx'), size], [Arch.n2r('edx'), prot]]
    # Reserve one slot for the 'int 0x80' gadget. None cannot be
    # decremented, so it is passed through as-is.
    pop_clmax = None if clmax is None else clmax - 1
    chain = popMultiple(args, constraint, assertion, pop_clmax, optimizeLen)
    if not chain:
        verbose("Failed to set registers for the mprotect syscall")
        return None
    # Int 0x80
    int80_gadgets = search(QueryType.INT80, None, None, constraint, assertion)
    if not int80_gadgets:
        verbose("Failed to find an 'int 80' gadget")
        return None
    chain.addChain(int80_gadgets[0])
    verbose("Success")
    return chain
Example #28
0
def MEMtoREG_transitivity(reg, arg2, constraint, assertion, n=1, clmax=LMAX):
    """
    Find MEMtoREG chains for `reg` that go through an intermediate register.

    The search proceeds in three steps:
    1. Skip candidate registers `inter` that are reg itself, forbidden by
       the constraint, or the IP/SP register.
    2. Search REGtoREG chains (reg <- inter) with a SearchRecord in which
       reg is marked unusable for REGtoREG.
    3. Search MEMtoREG chains for inter from `arg2` via _basic. Every pair
       (arg2 -> inter chain first, then inter -> reg) whose combined length
       fits in clmax is concatenated.

    Returns the collected chains as soon as there are at least n (all pairs
    for the current intermediate are kept, so there may be more than n).
    Otherwise returns whatever was found, possibly an empty list.
    """
    if clmax <= 0:
        return []

    found = []
    for inter in range(0, Arch.ssaRegCount):
        skip = (inter == reg or inter in constraint.getRegsNotModified()
                or inter == Arch.ipNum() or inter == Arch.spNum())
        if skip:
            continue

        # Find reg <- inter
        record = SearchRecord(maxdepth=4)
        record.unusable_REGtoREG.append(reg)
        to_reg = search(QueryType.REGtoREG,
                        reg, (inter, 0),
                        constraint,
                        assertion,
                        n,
                        clmax - 1,
                        record=record)
        if to_reg:
            shortest = min(len(c) for c in to_reg)
            # Basic strategy for inter <- arg2, leaving room for the
            # shortest reg <- inter chain
            to_inter = _basic(QueryType.MEMtoREG, inter, arg2,
                              constraint.add(Chainable(ret=True)),
                              assertion, n, clmax - shortest)
            for first in to_inter:
                for second in to_reg:
                    if len(first) + len(second) <= clmax:
                        found.append(first.addChain(second, new=True))
            # TODO: second strategy (read reg) is not implemented yet
        if len(found) >= n:
            return found
    # Return the best we got
    return found
Example #29
0
def getFunctionAddress(name):
    """
    Look up the function `name` among the binary's jump relocation entries.

    Returns (symbolName, r_offset) as (str, int) for the first entry whose
    symbol is `name`, or (None, None) when the current binary is not ELF or
    nothing matches. NOTE(review): r_offset of a jump-slot relocation is
    presumably the GOT slot rather than the PLT stub; confirm against
    callers.
    """
    if not Arch.currentIsELF():
        return (None, None)

    matches = (
        (entry.symbol.symbolName, entry.r_offset)
        for entry in binary_ELF.jumpRelocationEntries
        if entry.symbol.symbolName == name
    )
    return next(matches, (None, None))
Example #30
0
def parse_keep_regs(string):
    """
    Parse a 'keep registers' string into a list of register uids.

    Input: a comma-separated string such as "rax,rcx,rdi". Whitespace
        around each name is ignored, so "rax, rcx" is accepted too.
    Output if valid string: (True, list), where list holds the distinct
        register uids (e.g. [1, 3, 4]).
    Output if invalid string: (False, error_message)
    """
    keep_regs = set()
    for reg in string.split(','):
        # Tolerate spaces after commas, e.g. "rax, rcx"
        reg = reg.strip()
        if reg not in Arch.regNameToNum:
            return (False, "Error. '{}' is not a valid register".format(reg))
        keep_regs.add(Arch.n2r(reg))
    return (True, list(keep_regs))