def verify(self, gadget):
    """
    Check whether 'gadget' ends in a way allowed by this constraint.

    Returns a pair (ok, conditions):
      - ok is True when the gadget can satisfy the constraint
      - conditions lists the semantic conditions under which it does
        (empty when no condition is attached)
    """
    # No ending is required at all: every gadget is accepted
    if self.ret == self.jmp == self.call == False:
        return (True, [])

    # The recorded ending of the gadget is one of the requested kinds
    requested_kinds = ((self.ret, RetType.RET),
                       (self.jmp, RetType.JMP),
                       (self.call, RetType.CALL))
    for requested, kind in requested_kinds:
        if requested and gadget.retType == kind:
            return (True, [])

    # Ending not matched: inspect the semantics of the instruction pointer.
    # A 'ret' is possible when the final IP is the value stored in memory
    # just below the final SP, i.e. final_IP = MEM[SP + spInc - register size]
    if self.ret:
        for pair in gadget.getSemantics(Arch.ipNum()):
            if not isinstance(pair.expr, MEMExpr):
                continue
            (from_sp, delta) = pair.expr.addr.isRegIncrement(Arch.spNum())
            if from_sp and delta == (gadget.spInc - Arch.octets()):
                return (True, [pair.cond])

    # A jump is possible when the final IP is an SSAExpr
    if self.jmp:
        for pair in gadget.getSemantics(Arch.ipNum()):
            if isinstance(pair.expr, SSAExpr):
                return (True, [pair.cond])

    return (False, [])
def call(funcName, parsedArgs, constraint, assertion):
    """
    Build a ROPChain performing the syscall 'funcName' and print it.

    funcName   - name looked up in the table of syscalls available for the
                 current binary type
    parsedArgs - syscall arguments; their number must match the number of
                 arguments declared for the syscall
    constraint, assertion - forwarded to the syscall builder

    Problems are reported with error(); the function returns nothing.
    """
    # Pick the syscall table of the target system
    binType = Arch.currentBinType
    if binType == Arch.BinaryType.X86_ELF:
        available, system = Linux32.available, sysLinux32
    elif binType == Arch.BinaryType.X64_ELF:
        available, system = Linux64.available, sysLinux64
    else:
        error("Binary type '{}' not supported yet".format(Arch.currentBinType))
        return

    syscall = available.get(funcName)
    if not syscall:
        error("Syscall '{}' not supported for system '{}'".format(funcName, system))
        return
    if len(parsedArgs) != len(syscall.args):
        error("Error. Wrong number of arguments")
        return

    # Build the chain
    res = _build_syscall(syscall.buildFunc, parsedArgs, constraint, assertion)

    # Report the outcome
    if not res:
        print(string_bold("\n\tNo matching ROPChain found"))
        return
    print(string_bold("\n\tFound matching ROPChain\n"))
    badBytes = constraint.getBadBytes()
    if OUTPUT == OUTPUT_CONSOLE:
        print(res.strConsole(Arch.bits(), badBytes))
    elif OUTPUT == OUTPUT_PYTHON:
        print(res.strPython(Arch.bits(), badBytes))
def CSTtoMEM_write(arg1, cst, constraint, assertion, n=1, clmax=LMAX):
    """
    Write a constant to memory in two steps:
        reg       <- cst
        mem(arg1) <- reg

    arg1 is indexed as (address register, address constant).
    Returns a list of ROPChains no longer than clmax; the search stops as
    soon as n chains have been collected.
    """
    if clmax <= 0:
        return []

    (addr_reg, addr_cst) = (arg1[0], arg1[1])
    found = []

    # 1. First strategy (direct): load the constant in a scratch register,
    #    then store that register at the target address
    for reg in range(0, Arch.ssaRegCount):
        if reg in (Arch.ipNum(), Arch.spNum(), addr_reg):
            continue

        # reg <- cst, without clobbering the address register
        # (maxdepth 3 or it's too slow)
        keep_addr_reg = constraint.add(RegsNotModified([addr_reg]))
        loaders = search(QueryType.CSTtoREG, reg, cst, keep_addr_reg,
                         assertion, n, clmax - 1, maxdepth=3)
        if not loaders:
            continue

        # All the (reg2, cst2) such that mem(arg1) <- reg2 + cst2 is possible
        writes_by_reg = DBPossibleMemWrites(addr_reg, addr_cst, constraint,
                                            assertion, n=1)

        # 1.A. Only the exact writes are used: reg2 == reg and cst2 == 0
        writes_for_reg = writes_by_reg.get(reg)
        direct_writes = writes_for_reg.get(0, []) if writes_for_reg else []

        padding = constraint.getValidPadding(Arch.octets())
        for write_gadget in direct_writes:
            for loader in loaders:
                # The write gadget followed by the padding it consumes
                write_chain = ROPChain([write_gadget])
                for _ in range(0, write_gadget.spInc - Arch.octets(), Arch.octets()):
                    write_chain.addPadding(padding)
                candidate = loader.addChain(write_chain, new=True)
                if len(candidate) <= clmax:
                    found.append(candidate)
                if len(found) >= n:
                    return found
        # 1.B. (no further variant implemented)
    return found
def _CSTtoREG_transitivity(reg, cst, env, n=1):
    """
    Perform REG1 <- CST with REG1 <- REG2 <- CST

    reg - destination register
    cst - constant to load
    env - search environment (length budget, constraint, calls history)
    n   - maximum number of chains to return

    Returns a list of at most n ROPChains (empty when nothing is found or
    when the strategy is not applicable).
    """
    ID = StrategyType.CSTtoREG_TRANSITIVITY

    ## Special cases where the strategy is skipped
    # No room left for a chain
    if env.getLmax() <= 0:
        return []
    # Limit the number of calls to this strategy
    elif env.nbCalls(ID) >= 99:
        return []
    # The constant itself contains bad bytes
    elif not env.getConstraint().badBytes.verifyAddress(cst):
        return []
    # The previous call was already this strategy: transitivity is handled
    # by the REGtoREG transitivity, no need to recurse here as well
    elif env.callsHistory()[-1] == ID:
        return []

    # Set env
    env.addCall(ID)

    res = []
    for inter in Arch.registers():
        if (inter == reg
                or inter in env.getConstraint().getRegsNotModified()
                or inter == Arch.ipNum()
                or inter == Arch.spNum()):
            continue

        # Find reg <- inter
        inter_to_reg = _search(QueryType.REGtoREG, reg, (inter, 0), env, n)
        if inter_to_reg:
            # Now look for inter <- cst with the budget left by the shortest
            # reg <- inter chain, and without touching reg
            min_len = min([len(chain) for chain in inter_to_reg])
            env.subLmax(min_len)
            env.addUnusableReg(reg)
            # Floor division keeps the requested count an integer on both
            # Python 2 and Python 3
            cst_to_inter = _search(QueryType.CSTtoREG, inter, cst, env,
                                   n // len(inter_to_reg) + 1)
            env.removeUnusableReg(reg)
            env.addLmax(min_len)

            for chain2 in inter_to_reg:
                for chain1 in cst_to_inter:
                    if len(chain1) + len(chain2) <= env.getLmax():
                        res.append(chain1.addChain(chain2, new=True))

        # Did we get enough chains ?
        if len(res) >= n:
            break

    # Restore env
    env.removeCall(ID)
    return res[:n]
def _MEMtoREG_transitivity(reg, arg2, env, n=1):
    """
    Perform reg <- inter <- mem(arg2) through an intermediate register.

    Returns a list of at most n ROPChains when some were found, otherwise a
    FailRecord merging the failures of the sub-searches.
    """
    ID = StrategyType.MEMtoREG_TRANSITIVITY

    ## Special cases
    # No room left for a chain
    if env.getLmax() <= 0:
        return FailRecord(lmax=True)
    # Too many calls to this strategy, or the previous call was already this
    # strategy (transitivity is handled by the REGtoREG transitivity, so
    # there is no need to also recurse here)
    if env.nbCalls(ID) >= 99 or env.callsHistory()[-1] == ID:
        return FailRecord()

    # Set env
    env.addCall(ID)

    chains = []
    failures = FailRecord()
    for inter in Arch.registers():
        if (inter == reg
                or inter in env.getConstraint().getRegsNotModified()
                or inter == Arch.ipNum()
                or inter == Arch.spNum()):
            continue

        # Find reg <- inter
        inter_to_reg = _search(QueryType.REGtoREG, reg, (inter, 0), env, n)
        if not inter_to_reg:
            failures.merge(inter_to_reg)
        else:
            shortest = min(len(chain) for chain in inter_to_reg)
            # Find inter <- mem(arg2) with the remaining budget, keeping reg
            # out of the search
            env.subLmax(shortest)
            env.addUnusableReg(reg)
            arg2_to_inter = _search(QueryType.MEMtoREG, inter, arg2, env, n)
            env.removeUnusableReg(reg)
            env.addLmax(shortest)

            if not arg2_to_inter:
                failures.merge(arg2_to_inter)
                continue

            for chain1 in arg2_to_inter:
                for chain2 in inter_to_reg:
                    if len(chain1) + len(chain2) <= env.getLmax():
                        chains.append(chain1.addChain(chain2, new=True))

        # Did we get enough chains ?
        if len(chains) >= n:
            break

    # Restore env
    env.removeCall(ID)
    return chains[:n] if chains else failures
def _CSTtoREG_transitivity(reg, cst, constraint, assertion, n=1, clmax=LMAX,
                           comment=None, maxdepth=4):
    """
    Perform REG1 <- CST with REG1 <- REG2 <- CST

    reg      - destination register
    cst      - constant to load
    n        - number of chains wanted
    clmax    - maximum length of a returned chain
    comment  - forwarded to the 'pop' strategy
    maxdepth - depth given to the REGtoREG search record

    Returns a list of ROPChains (at most n as soon as n were found).
    """
    # Test clmax
    if clmax <= 0:
        return []

    res = []

    def _combine(inter_to_reg, cst_to_inter):
        # Append every (inter <- cst) + (reg <- inter) pair that fits in clmax
        for chain2 in inter_to_reg:
            for chain1 in cst_to_inter:
                if len(chain1) + len(chain2) <= clmax:
                    res.append(chain1.addChain(chain2, new=True))

    for inter in range(0, Arch.ssaRegCount):
        if (inter == reg
                or inter in constraint.getRegsNotModified()
                or inter == Arch.ipNum()
                or inter == Arch.spNum()):
            continue

        # Find reg <- inter
        REGtoREG_record = SearchRecord(maxdepth=maxdepth)
        REGtoREG_record.unusable_REGtoREG.append(reg)
        inter_to_reg = search(QueryType.REGtoREG, reg, (inter, 0), constraint,
                              assertion, n, clmax, record=REGtoREG_record)
        if inter_to_reg:
            # Floor division keeps the requested count an integer on both
            # Python 2 and Python 3
            wanted = n // len(inter_to_reg) + 1

            # First try inter <- cst with the basic strategy
            cst_to_inter = _basic(QueryType.CSTtoREG, inter, cst, constraint,
                                  assertion, wanted, clmax - 1)
            _combine(inter_to_reg, cst_to_inter)

            # Not enough yet: try inter <- cst by popping it from the stack
            if len(res) < n:
                cst_to_inter = _CSTtoREG_pop(inter, cst, constraint, assertion,
                                             wanted, clmax - 1, comment)
                _combine(inter_to_reg, cst_to_inter)

        # Did we get enough chains ?
        if len(res) >= n:
            return res[:n]

    # Return what we got
    return res
def _REGtoMEM_transitivity(arg1, arg2, env, n=1):
    """
    Perform mem(arg1) <- arg2 through an intermediate register:
        inter     <- arg2
        mem(arg1) <- inter

    arg2 is indexed as (register, constant).
    Returns the list of ROPChains found, or a FailRecord merging the
    failures of the sub-searches when there is none.
    """
    ID = StrategyType.REGtoMEM_TRANSITIVITY

    ## Special cases
    # No room left for a chain
    if env.getLmax() <= 0:
        return FailRecord(lmax=True)
    # Too many calls to this strategy, or the previous call was already this
    # strategy (transitivity is handled by the REGtoREG transitivity, so
    # there is no need to also recurse here)
    if env.nbCalls(ID) >= 99 or env.callsHistory()[-1] == ID:
        return FailRecord()

    # Set env
    env.addCall(ID)

    chains = []
    failures = FailRecord()
    for inter in Arch.registers():
        if (inter == arg2[0]
                or inter in env.getConstraint().getRegsNotModified()
                or inter == Arch.ipNum()
                or inter == Arch.spNum()):
            continue

        # Find inter <- arg2
        arg2_to_inter = _search(QueryType.REGtoREG, inter, (arg2[0], arg2[1]), env, n)
        if not arg2_to_inter:
            failures.merge(arg2_to_inter)
            continue

        # Find mem(arg1) <- inter with the remaining budget, keeping the
        # source register out of the search
        shortest = min(len(chain) for chain in arg2_to_inter)
        env.subLmax(shortest)
        env.addUnusableReg(arg2[0])
        inter_to_mem = _search(QueryType.REGtoMEM, arg1, (inter, 0), env)
        env.removeUnusableReg(arg2[0])
        env.addLmax(shortest)

        if not inter_to_mem:
            failures.merge(inter_to_mem)
            continue

        for chain1 in arg2_to_inter:
            for chain2 in inter_to_mem:
                if len(chain1) + len(chain2) <= env.getLmax():
                    chains.append(chain1.addChain(chain2, new=True))

        # Did we get enough chains ?
        if len(chains) >= n:
            break

    # Restore env
    env.removeCall(ID)
    return chains if chains else failures
def build_mprotect64(addr, size, prot=7, constraint=None, assertion=None,
                     clmax=SYSCALL_LMAX, optimizeLen=False):
    """
    Call mprotect from X86-64 arch

    The function is first called directly when it can be found; otherwise
    the syscall is done by hand:
      - args must be in registers (rdi, rsi, rdx),
        sizes are (unsigned long, size_t, unsigned long)
      - rax must be 10

    Returns the ROPChain, or None on failure.
    """
    # Check args: all three must be integers
    for value in (addr, size, prot):
        if not isinstance(value, int):
            error("Argument error. Expected integer, got " + str(type(value)))
            return None

    constraint = Constraint() if constraint is None else constraint
    assertion = Assertion() if assertion is None else assertion

    # Check if we have the function !
    verbose("Trying to call mprotect() function directly")
    func_call = build_call('mprotect', [addr, size, prot], constraint,
                           assertion, clmax=clmax, optimizeLen=optimizeLen)
    # build_call reports a failure by returning an error string
    if not isinstance(func_call, str):
        verbose("Success")
        return func_call
    if constraint.chainable.ret:
        verbose("Couldn't call mprotect() and return to ROPChain")
        return None
    verbose("Coudn't call mprotect(), try direct syscall")

    # Otherwise do the syscall by 'hand'
    # Set the registers
    args = [[Arch.n2r('rdi'), addr], [Arch.n2r('rsi'), size],
            [Arch.n2r('rdx'), prot], [Arch.n2r('rax'), 10]]
    chain = popMultiple(args, constraint, assertion, clmax - 1, optimizeLen)
    if not chain:
        verbose("Failed to set registers for the mprotect syscall")
        return None

    # Syscall
    syscalls = search(QueryType.SYSCALL, None, None, constraint, assertion)
    if not syscalls:
        verbose("Failed to find a syscall gadget")
        return None

    chain.addChain(syscalls[0])
    verbose("Success")
    return chain
def build_call(funcName, funcArgs, constraint, assertion):
    """
    Build a ROPChain calling 'funcName' with its arguments on the stack.

    The chain is laid out as:
        function address | fake return address | arg1 | ... | argN
    where the fake return address is a gadget that loads the instruction
    pointer from the stack past the arguments.

    funcArgs - list of integer arguments

    Returns the ROPChain on success, or a str describing the failure.
    """
    # Find the address of the function
    (funcName2, funcAddr) = getFunctionAddress(funcName)
    if funcName2 is None:
        return "Couldn't find function '{}' in the binary".format(funcName)

    # Check if bad bytes in function address
    if not constraint.badBytes.verifyAddress(funcAddr):
        return "'{}' address ({}) contains bad bytes".format(
            funcName2,
            string_special('0x' + format(funcAddr, '0' + str(Arch.octets() * 2) + 'x')))

    # Find a gadget for the fake return address.
    # Stack slots have the size of a register of the current architecture
    # (instead of a hard-coded 8), like the address formatting above.
    word = Arch.octets()
    offset = (len(funcArgs) - 1) * word  # one word is added at the beginning of the loop
    skip_args_chains = []
    tries = 4  # try at most 4 increasing offsets
    while tries > 0 and not skip_args_chains:
        offset += word
        skip_args_chains = search(QueryType.MEMtoREG, Arch.ipNum(),
                                  (Arch.spNum(), offset), constraint,
                                  assertion, n=1)
        tries -= 1
    if not skip_args_chains:
        return "Couldn't build ROP-Chain"
    skip_args_chain = skip_args_chains[0]

    # Build the ropchain with the arguments
    args_chain = ROPChain()
    arg_n = len(funcArgs)
    for arg in reversed(funcArgs):
        if isinstance(arg, int):
            args_chain.addPadding(arg, comment="Arg{}: {}".format(
                arg_n, string_ropg(hex(arg))))
            arg_n -= 1
        else:
            return "Type of argument '{}' not supported yet :'(".format(arg)

    # Build call chain (function address + fake return address)
    call_chain = ROPChain()
    call_chain.addPadding(funcAddr, comment=string_ropg(funcName2))
    skip_args_addr = int(
        validAddrStr(skip_args_chain.chain[0], constraint.getBadBytes(),
                     Arch.bits()), 16)
    call_chain.addPadding(skip_args_addr, comment="Address of: " +
                          string_bold(str(skip_args_chain.chain[0])))

    return call_chain.addChain(args_chain)
def build_call_linux86(funcName, funcArgs, constraint, assertion, clmax=None,
                       optimizeLen=False):
    """
    Build a ROPChain calling 'funcName' with its arguments on the stack.

    The chain is laid out as:
        function address | fake return address | arg1 | ... | argN
    The fake return address (only present when there are arguments) is a
    gadget that loads the instruction pointer from the stack past the
    arguments.

    funcArgs - list of integer arguments
    clmax    - maximum length of the chain; None means no limit is checked

    Returns the ROPChain on success, or a str describing the failure.
    """
    # Find the address of the function
    (funcName2, funcAddr) = getFunctionAddress(funcName)
    if funcName2 is None:
        return "Couldn't find function '{}' in the binary".format(funcName)

    # Check if bad bytes in function address
    if not constraint.badBytes.verifyAddress(funcAddr):
        return "'{}' address ({}) contains bad bytes".format(
            funcName2,
            string_special('0x' + format(funcAddr, '0' + str(Arch.octets() * 2) + 'x')))

    # Check if lmax too small: function address + arguments + fake return
    # address when there are arguments. The comparison is skipped when no
    # limit was given (comparing against None rejects every call on
    # Python 2 and raises TypeError on Python 3).
    needed = 1 + len(funcArgs) + (1 if funcArgs else 0)
    if clmax is not None and needed > clmax:
        return "Not enough bytes to call function '{}'".format(funcName)

    # Find a gadget for the fake return address
    if funcArgs:
        word = Arch.octets()
        offset = (len(funcArgs) - 1) * word  # one word is added at the beginning of the loop
        skip_args_chains = []
        tries = 4  # try at most 4 increasing offsets
        while tries > 0 and not skip_args_chains:
            offset += word
            skip_args_chains = search(QueryType.MEMtoREG, Arch.ipNum(),
                                      (Arch.spNum(), offset), constraint,
                                      assertion, n=1, optimizeLen=optimizeLen)
            tries -= 1
        if not skip_args_chains:
            return "Couldn't build ROP-Chain"
        skip_args_chain = skip_args_chains[0]
    else:
        # No arguments
        skip_args_chain = None

    # Build the ropchain with the arguments
    args_chain = ROPChain()
    arg_n = len(funcArgs)
    for arg in reversed(funcArgs):
        if isinstance(arg, int):
            args_chain.addPadding(arg, comment="Arg{}: {}".format(
                arg_n, string_ropg(hex(arg))))
            arg_n -= 1
        else:
            return "Type of argument '{}' not supported yet :'(".format(arg)

    # Build call chain (function address + fake return address)
    call_chain = ROPChain()
    call_chain.addPadding(funcAddr, comment=string_ropg(funcName2))
    if funcArgs:
        skip_args_addr = int(
            validAddrStr(skip_args_chain.chain[0], constraint.getBadBytes(),
                         Arch.bits()), 16)
        call_chain.addPadding(skip_args_addr, comment="Address of: " +
                              string_bold(str(skip_args_chain.chain[0])))

    return call_chain.addChain(args_chain)
def init_impossible_REGtoREG(env):
    """
    Run a REGtoREG search for every ordered pair of registers so that 'env'
    records the transfers it finds impossible, then report how many pairs
    were marked impossible and how long it took.

    A failure or a user interruption stops the process with a warning
    instead of propagating; 'env' then keeps whatever was recorded so far.
    """
    try:
        startTime = datetime.now()
        registers = Arch.registers()
        nb_regs = len(registers)
        # Total number of steps shown by the progress bar
        total = nb_regs * nb_regs
        done = 0
        impossible_count = 0
        for reg1 in sorted(registers):
            for reg2 in registers:
                done += 1
                charging_bar(total, done, 30)
                if reg2 == reg1 or reg2 == Arch.ipNum():
                    continue
                # Only the side effect on env matters, not the chains found
                _search(QueryType.REGtoREG, reg1, (reg2, 0), env, n=1)
                if env.checkImpossible_REGtoREG(reg1, reg2, 0):
                    impossible_count += 1
        cTime = datetime.now() - startTime

        # Get how many impossible paths we found
        nb_pairs = (nb_regs - 1) * nb_regs
        if nb_pairs:
            impossible_rate = int(100 * (float(impossible_count) / float(nb_pairs)))
        else:
            # Fewer than two registers: nothing to optimize
            impossible_rate = 0
        notify('Optimization rate : {}%'.format(impossible_rate))
        notify("Computation time : " + str(cTime))
    except (Exception, KeyboardInterrupt):
        # Best effort: the engine still works without this optimization.
        # Ctrl-C is caught on purpose so the user can skip a long init.
        # NOTE(review): the previous code rebound the local name 'env' to a
        # fresh SearchEnvironment here, which had no effect for the caller;
        # a caller needing a clean environment has to create it itself.
        print("\n")
        fatal("Exception caught, stopping Semantic Engine init process...\n")
        fatal("Search time might get very long !\n")
def init_impossible_REGtoREG(env):
    """
    Run a REGtoREG search for every ordered pair of registers so that 'env'
    records the transfers it finds impossible, then report how many pairs
    were marked impossible and how long it took.
    """
    startTime = datetime.now()
    registers = Arch.registers()
    nb_regs = len(registers)
    # Total number of steps shown by the progress bar
    total = nb_regs * nb_regs
    done = 0
    impossible_count = 0
    for reg1 in sorted(registers):
        for reg2 in registers:
            done += 1
            charging_bar(total, done, 30)
            if reg2 == reg1 or reg2 == Arch.ipNum():
                continue
            # Only the side effect on env matters, not the chains found
            _search(QueryType.REGtoREG, reg1, (reg2, 0), env, n=1)
            if env.checkImpossible_REGtoREG(reg1, reg2, 0):
                impossible_count += 1
    cTime = datetime.now() - startTime

    # Get how many impossible paths we found
    nb_pairs = (nb_regs - 1) * nb_regs
    if nb_pairs:
        impossible_rate = int(100 * (float(impossible_count) / float(nb_pairs)))
    else:
        # Fewer than two registers: nothing to optimize
        impossible_rate = 0
    notify('Optimization rate : {}%'.format(impossible_rate))
    notify("Computation time : " + str(cTime))
def build_call_linux64(funcName, funcArgs, constraint, assertion, clmax=None,
                       optimizeLen=False):
    """
    Build a ROPChain calling 'funcName' with its arguments in registers
    (rdi, rsi, rdx, rcx, r8, r9, in that order), at most 6 arguments.

    Each element of funcArgs is appended to a one-element tuple holding the
    register, so it has to be a tuple itself.

    Returns the result of adding the function address to the chain that
    loads the arguments, or a str describing the failure.
    """
    # Registers receiving the arguments, in calling convention order
    argsRegs = [Arch.n2r(regName)
                for regName in ['rdi', 'rsi', 'rdx', 'rcx', 'r8', 'r9']]

    # Find the address of the function
    (funcName2, funcAddr) = getFunctionAddress(funcName)
    if funcName2 is None:
        return "Couldn't find function '{}' in the binary".format(funcName)

    # Check if bad bytes in function address
    if not constraint.badBytes.verifyAddress(funcAddr):
        hexAddr = '0x' + format(funcAddr, '0' + str(Arch.octets() * 2) + 'x')
        return "'{}' address ({}) contains bad bytes".format(
            funcName2, string_special(hexAddr))

    # Check how many arguments
    if len(funcArgs) > 6:
        return "Doesn't support function call with more than 6 arguments with Linux X64 calling convention :("

    if not funcArgs:
        # No arguments: start from an empty chain
        args_chain = ROPChain()
    else:
        # Build the ropchain loading every argument in its register
        def bindToReg(argReg, argValue):
            return (argReg,) + argValue

        args_chain = popMultiple(
            map(bindToReg, argsRegs[:len(funcArgs)], funcArgs),
            constraint, assertion, clmax=clmax, optimizeLen=optimizeLen)
        if not args_chain:
            return "Couldn't load arguments in registers"

    # The function address comes right after the arguments setup
    return args_chain.addPadding(funcAddr, comment=string_ropg(funcName2))
def getFunctionAddress(name):
    """
    Looks for the function 'name' in the '.rela.plt' relocations of the
    binary, then in its symbol table sections.

    Returns a pair (name, address) as (str, int), or (None, None) when the
    current binary is not an ELF or the function is not found.
    """
    if not Arch.currentIsELF():
        return (None, None)

    # Get function in relocations
    relasec_name = '.rela.plt'
    relasec = binary_ELF.get_section_by_name(relasec_name)
    if isinstance(relasec, RelocationSection):
        relasec_addr = relasec.header['sh_addr']
        symbols = binary_ELF.get_section(relasec.header['sh_link'])
        if not isinstance(symbols, NullSection):
            for reloc in relasec.iter_relocations():
                if symbols.get_symbol(reloc['r_info_sym']).name == name:
                    return (name, reloc['r_offset'] + relasec_addr)
    else:
        # No usable relocation section: report it and fall back on the
        # symbol tables instead of dereferencing the missing section
        print(' ERROR DEBUG The file has no %s section' % relasec_name)

    # Get function from symbol table sections
    for symsec in getSymbolSections():
        function = symsec.get_symbol_by_name(name)
        if function:
            return (name, function[0]['st_value'])

    return (None, None)
def find(args):
    """
    Handle the 'find' command: parse the query, search matching ROPChains
    and print them; when there is none, print the gadgets found by the
    not-chainable search instead.

    args - List of user arguments as strings
           (the command should not be included in the list as args[0])
    """
    # No argument, or explicit help request
    if not args or args[0] == OPTION_HELP or args[0] == OPTION_HELP_SHORT:
        print_help()
        return

    parsed_args = parse_args(args)
    if not parsed_args[0]:
        # Parsing failed: the second element is the error message
        error(parsed_args[1])
        return

    (qtype, arg1, arg2, constraint, nbResults, clmax) = parsed_args[1:7]

    # Memory around the stack pointer is assumed readable / writable
    stack_readable = RegsValidPtrRead([(Arch.spNum(), -5000, 10000)])
    stack_writable = RegsValidPtrWrite([(Arch.spNum(), -5000, 0)])
    assertion = Assertion().add(stack_readable).add(stack_writable)

    # Search
    res = search(qtype, arg1, arg2, constraint, assertion, n=nbResults,
                 clmax=clmax)
    if res:
        title = "Built matching ROPChain(s)"
    else:
        res = search_not_chainable(qtype, arg1, arg2, constraint, assertion,
                                   n=nbResults, clmax=clmax)
        title = "Possibly matching gadget(s)"
    print_chains(res, title, constraint.getBadBytes())
def _REGtoREG_transitivity(arg1, arg2, constraint, assertion, record, n=1,
                           clmax=LMAX):
    """
    Perform REG1 <- REG2+CST with REG1 <- REG3 <- REG2+CST

    arg1   - destination register
    arg2   - source, indexed as (register, constant)
    record - search record; its 'unusable_REGtoREG' list is extended during
             the sub-searches and always left as it was received
    n      - number of chains wanted
    clmax  - maximum length of a returned chain

    Returns a list of ROPChains (the search stops once n were found).
    """
    # Test clmax
    if clmax <= 0:
        return []
    # If reg1 <- reg1 + 0, return
    if arg1 == arg2[0] and arg2[1] == 0:
        return []

    res = []
    for inter_reg in range(0, Arch.ssaRegCount):
        if (inter_reg == arg1
                or (inter_reg == arg2[0] and arg2[1] == 0)
                or (inter_reg in record.unusable_REGtoREG)
                or inter_reg == Arch.ipNum()
                or inter_reg == Arch.spNum()):
            continue

        # Find reg1 <- inter_reg without using arg2
        record.unusable_REGtoREG.append(arg2[0])
        try:
            inter_to_arg1_list = search(QueryType.REGtoREG, arg1,
                                        (inter_reg, 0), constraint, assertion,
                                        n, clmax=clmax - 1, record=record)
        finally:
            record.unusable_REGtoREG.remove(arg2[0])
        if not inter_to_arg1_list:
            continue

        # Floor division keeps the requested count an integer on both
        # Python 2 and Python 3 (and lets the zero test below trigger)
        n2 = n // len(inter_to_arg1_list)
        if n2 == 0:
            n2 = 1

        # Find inter_reg <- arg2 without using arg1. The finally clause
        # also undoes the marking when returning early with enough chains.
        record.unusable_REGtoREG.append(arg1)
        try:
            for arg2_to_inter in search(QueryType.REGtoREG, inter_reg, arg2,
                                        constraint, assertion, n2,
                                        clmax=clmax - 1, record=record):
                for inter_to_arg1 in inter_to_arg1_list:
                    if len(inter_to_arg1) + len(arg2_to_inter) <= clmax:
                        res.append(arg2_to_inter.addChain(inter_to_arg1, new=True))
                    if len(res) >= n:
                        return res
        finally:
            record.unusable_REGtoREG.remove(arg1)

    return res
def initEngine():
    """
    Initialize the global state of the semantic engine: the base assertion
    on the stack pointer and the search environment that records the
    impossible REGtoREG transfers.
    """
    global INIT_LMAX, INIT_MAXDEPTH
    global global_impossible_REGtoREG
    global baseAssertion

    # Init global variables: memory around the stack pointer is assumed
    # readable / writable
    assertion = Assertion()
    assertion = assertion.add(
        RegsValidPtrRead([(Arch.spNum(), -5000, 10000)]), copy=False)
    assertion = assertion.add(
        RegsValidPtrWrite([(Arch.spNum(), -5000, 0)]), copy=False)
    baseAssertion = assertion

    info(string_bold("Initializing Semantic Engine\n"))

    # Init helper for REGtoREG
    global_impossible_REGtoREG = SearchEnvironment(
        INIT_LMAX, Constraint(), baseAssertion, INIT_MAXDEPTH)
    init_impossible_REGtoREG(global_impossible_REGtoREG)
def build_mprotect64(addr, size, prot=7, constraint=None, assertion=None,
                     clmax=None, optimizeLen=False):
    """
    Call mprotect from X86-64 arch

    Args must be in registers (rdi, rsi, rdx):
    sizes are (unsigned long, size_t, unsigned long)
    rax must be 10

    clmax - maximum length of the chain; None falls back on SYSCALL_LMAX

    Returns the ROPChain, or None on failure.
    """
    # Check args: all three must be integers
    for value in (addr, size, prot):
        if not isinstance(value, int):
            error("Argument error. Expected integer, got " + str(type(value)))
            return None

    if constraint is None:
        constraint = Constraint()
    if assertion is None:
        assertion = Assertion()
    # The default None cannot be used in 'clmax - 1' below
    # NOTE(review): assumes the module-level SYSCALL_LMAX is the intended
    # default length for syscall chains - confirm
    if clmax is None:
        clmax = SYSCALL_LMAX

    # Set the registers
    args = [[Arch.n2r('rdi'), addr], [Arch.n2r('rsi'), size],
            [Arch.n2r('rdx'), prot], [Arch.n2r('rax'), 10]]
    # One element of the chain is kept for the syscall gadget
    chain = popMultiple(args, constraint, assertion, clmax - 1, optimizeLen)
    if not chain:
        verbose("Failed to set registers for the mprotect syscall")
        return None

    # Syscall
    syscalls = search(QueryType.SYSCALL, None, None, constraint, assertion)
    if not syscalls:
        verbose("Failed to find a syscall gadget")
        return None

    chain.addChain(syscalls[0])
    verbose("Success")
    return chain
def __init__(self, num, ind=0, size=None):
    """
    num  - either an SSAReg, which is copied ('ind' is then ignored), or
           the register number
    ind  - index of the register when 'num' is a number
    size - size of the expression; defaults to Arch.bits()
    """
    Expr.__init__(self)
    # Always store a fresh SSAReg, never the one given by the caller
    if isinstance(num, SSAReg):
        (regNum, regInd) = (num.num, num.ind)
    else:
        (regNum, regInd) = (num, ind)
    self.reg = SSAReg(regNum, regInd)
    self.size = Arch.bits() if size is None else size
def initScanner(filename):
    """
    Remember the name of the binary and, when the current binary is an ELF,
    parse it with ElfParser (the parsed binary is set to None otherwise).
    """
    global binary_name
    global binary_ELF

    binary_name = filename
    binary_ELF = ElfParser(binary_name) if Arch.currentIsELF() else None
def _CSTtoREG_pop(reg, cst, constraint, assertion, n=1, clmax=LMAX,
                  comment=None):
    """
    Returns payloads that put cst into register reg by popping it from the
    stack.

    comment - comment attached to the constant in the chain (a default one
              is generated when not given)

    Returns a list of ROPChains no longer than clmax; the search stops as
    soon as n chains have been collected.
    """
    # Nothing to return without room for a chain or without wanted results
    if clmax <= 0 or n < 1:
        return []
    # The constant is incompatible with the bad bytes of the constraint
    if not constraint.badBytes.verifyAddress(cst):
        return []

    if not comment:
        comment = "Constant: " + string_bold("0x{:x}".format(cst))

    # Loading the instruction pointer drops the chainable requirement;
    # any other register requires a gadget ending with a ret
    if reg == Arch.ipNum():
        constraint2 = constraint.remove([CstrTypeID.CHAINABLE])
    else:
        constraint2 = constraint.add(Chainable(ret=True))

    # Direct pop from the stack: gadgets grouped by the stack offset they
    # read the register from
    possible = DBPossiblePopOffsets(reg, constraint2, assertion)
    res = []
    for offset in sorted(off for off in possible.keys() if off >= 0):
        # Offsets are sorted: the next ones cannot fit in clmax either
        if offset > clmax * Arch.octets():
            break

        # Keep the gadgets whose stack increment covers the offset and
        # whose padding fits in clmax
        candidates = []
        for g in possible[offset]:
            if g.spInc < Arch.octets():
                continue
            if g.spInc - Arch.octets() <= offset:
                continue
            if (g.spInc/Arch.octets()-1) > clmax:
                continue
            candidates.append(g)

        # Pad the gadgets, putting the constant in the slot that is popped
        padding = constraint.getValidPadding(Arch.octets())
        for gadget in candidates:
            chain = ROPChain([gadget])
            for slot in range(0, gadget.spInc - Arch.octets(), Arch.octets()):
                if slot == offset:
                    chain.addPadding(cst, comment)
                else:
                    chain.addPadding(padding)
            if len(chain) <= clmax:
                res.append(chain)
                if len(res) >= n:
                    return res
    return res
def initScanner(filename):
    """
    Remember the name of the binary and, when the current binary is an ELF,
    open it and parse it with ELFFile (the parsed binary is set to None
    otherwise).

    The file is only opened for ELF binaries. It is left open on success
    because the ELFFile object is built on the stream.
    """
    global binary_name
    global binary_ELF

    binary_name = filename
    if Arch.currentIsELF():
        stream = open(binary_name, 'rb')
        try:
            binary_ELF = ELFFile(stream)
        except Exception:
            # Do not leak the handle when the file cannot be parsed
            stream.close()
            raise
    else:
        binary_ELF = None
def addOffset(self, offset):
    """
    Shift every address of self.addrList by 'offset'.

    Returns True when all the addresses were shifted. Returns False and
    leaves self.addrList untouched when one of the shifted addresses is
    negative or does not fit in Arch.bits() bits.
    """
    shifted = []
    for addr in self.addrList:
        moved = addr + offset
        # Check that the offset isn't too big
        if not (0 <= moved < (0x1 << Arch.bits())):
            return False
        shifted.append(moved)
    self.addrList = shifted
    return True
def find_best_valid_writes(addr, string, constraint, limit=None):
    """
    When using the write strategy, can have bad bytes in addresses too...
    Try to adjust it.

    The string is cut in chunks of at most Arch.octets() bytes such that
    the address following each chunk passes the bad bytes check. When no
    cut is possible from the current start address, the next start address
    is tried, as long as the string ends at or before 'limit'.

    addr   - first address to try
    string - string to write
    limit  - highest end address allowed (default: addr + len(string) + 10)

    Returns a list of (address, value) writes, or None when no start
    address works. A value is None when no valid padding exists for an
    incomplete chunk.
    """
    def string_into_reg(string):
        # Pack a chunk of the string in a register-sized integer, completed
        # with valid padding bytes placed after the chunk in memory order
        bytes_list = [b if isinstance(b, int) else ord(b) for b in string]
        pad_len = Arch.octets() - len(bytes_list)
        if pad_len != 0:
            value = constraint.getValidPadding(pad_len)
            if value is None:
                return None
        else:
            value = 0

        if Arch.isLittleEndian():
            # First byte of the chunk ends up in the lowest byte
            for byte in reversed(bytes_list):
                value = (value << 8) + byte
            return value
        elif Arch.isBigEndian():
            # First byte of the chunk ends up in the highest byte, the
            # padding fills the low bytes
            tmp = 0
            for byte in bytes_list:
                tmp = (tmp << 8) + byte
            return (tmp << (8 * max(pad_len, 0))) + value
        else:
            return None

    tmp_addr = addr
    if not limit:
        limit = addr + len(string) + 10

    while tmp_addr + len(string) <= limit:
        res = []
        offset = 0
        fail = False
        while offset < len(string):
            # Biggest chunk such that the next write address is valid
            size = None
            for i in reversed(range(1, Arch.octets() + 1)):
                if constraint.badBytes.verifyAddress(tmp_addr + offset + i):
                    size = i
                    break
            if size is None:
                fail = True
                break
            value = string_into_reg(string[offset:offset + size])
            res.append((tmp_addr + offset, value))
            offset += size
        if not fail:
            return res
        # Retry from the next start address: without this step the loop
        # would try the same address over and over
        tmp_addr += 1
    return None
def verify(self, gadget):
    """
    Check whether 'gadget' has at least one address free of bad bytes.

    Every address is written as Arch.octets() zero-padded hexadecimal
    bytes and compared against the byte strings of self.bytes.

    Returns (True, []) as soon as one address contains none of them,
    (False, []) otherwise.
    """
    for addr in gadget.addrList:
        # Hexadecimal digits of the address, grouped two by two
        addrBytes = re.findall('..', format(addr, '0' + str(Arch.octets() * 2) + 'x'))
        # No bad bytes found, so valid address
        if not any(byte in addrBytes for byte in self.bytes):
            return (True, [])
    return (False, [])
def string_into_reg(string):
    """
    Pack 'string' (at most Arch.octets() bytes) in a register-sized integer
    such that writing the register to memory stores the string first,
    followed by valid padding bytes when the string is shorter than a
    register.

    The padding comes from the 'constraint' of the enclosing scope.

    Returns the integer, or None when no valid padding exists or when the
    endianness of the architecture is unknown.
    """
    # Accept both characters and integer bytes
    bytes_list = [b if isinstance(b, int) else ord(b) for b in string]

    # Get base value: the padding completing the register
    pad_len = Arch.octets() - len(bytes_list)
    if pad_len != 0:
        value = constraint.getValidPadding(pad_len)
        if value is None:
            return None
    else:
        value = 0

    if Arch.isLittleEndian():
        # First byte of the string ends up in the lowest byte
        for byte in reversed(bytes_list):
            value = (value << 8) + byte
        return value
    elif Arch.isBigEndian():
        # First byte of the string ends up in the highest byte, the
        # padding fills the low bytes
        tmp = 0
        for byte in bytes_list:
            tmp = (tmp << 8) + byte
        return (tmp << (8 * max(pad_len, 0))) + value
    else:
        return None
def build_mprotect32(addr, size, prot=7, constraint=None, assertion=None,
                     clmax=None, optimizeLen=False):
    """
    Call mprotect from X86 arch with an 'int 0x80' syscall

    int mprotect(void *addr, size_t len, int prot)
    args must be in registers (ebx, ecx, edx)
    eax must be 0x7d = 125

    clmax - maximum length of the chain; None falls back on SYSCALL_LMAX

    Returns the ROPChain, or None on failure.
    """
    # Check args: all three must be integers
    for value in (addr, size, prot):
        if not isinstance(value, int):
            error("Argument error. Expected integer, got " + str(type(value)))
            return None

    if constraint is None:
        constraint = Constraint()
    if assertion is None:
        assertion = Assertion()
    # The default None cannot be used in 'clmax - 1' below
    # NOTE(review): assumes the module-level SYSCALL_LMAX is the intended
    # default length for syscall chains - confirm
    if clmax is None:
        clmax = SYSCALL_LMAX

    # Set the registers
    args = [[Arch.n2r('eax'), 0x7d], [Arch.n2r('ebx'), addr],
            [Arch.n2r('ecx'), size], [Arch.n2r('edx'), prot]]
    # One element of the chain is kept for the 'int 0x80' gadget
    chain = popMultiple(args, constraint, assertion, clmax - 1, optimizeLen)
    if not chain:
        verbose("Failed to set registers for the mprotect syscall")
        return None

    # Int 0x80
    int80_gadgets = search(QueryType.INT80, None, None, constraint, assertion)
    if not int80_gadgets:
        verbose("Failed to find an 'int 80' gadget")
        return None

    chain.addChain(int80_gadgets[0])
    verbose("Success")
    return chain
def MEMtoREG_transitivity(reg, arg2, constraint, assertion, n=1, clmax=LMAX):
    """
    Perform reg <- mem(arg2) through an intermediate register:
        inter <- mem(arg2)
        reg   <- inter

    Returns a list of ROPChains no longer than clmax; the search stops once
    at least n chains have been collected.
    """
    if clmax <= 0:
        return []

    found = []
    for inter in range(0, Arch.ssaRegCount):
        if (inter == reg
                or inter in constraint.getRegsNotModified()
                or inter == Arch.ipNum()
                or inter == Arch.spNum()):
            continue

        # Find reg <- inter
        REGtoREG_record = SearchRecord(maxdepth=4)
        REGtoREG_record.unusable_REGtoREG.append(reg)
        inter_to_reg = search(QueryType.REGtoREG, reg, (inter, 0), constraint,
                              assertion, n, clmax - 1, record=REGtoREG_record)
        if inter_to_reg:
            shortest = min(len(chain) for chain in inter_to_reg)
            # Find inter <- mem(arg2) with the basic strategy, using gadgets
            # ending with a ret and the budget left by the shortest
            # reg <- inter chain
            arg2_to_inter = _basic(QueryType.MEMtoREG, inter, arg2,
                                   constraint.add(Chainable(ret=True)),
                                   assertion, n, clmax - shortest)
            for chain1 in arg2_to_inter:
                for chain2 in inter_to_reg:
                    if len(chain1) + len(chain2) <= clmax:
                        found.append(chain1.addChain(chain2, new=True))
            # Second strategy, reading through a register: TODO

        # Did we get enough chains ?
        if len(found) >= n:
            return found

    # Return the best we got
    return found
def getFunctionAddress(name):
    """
    Looks for the function 'name' in the jump relocation entries of the
    binary.

    Returns a pair (name, address) as (str, int), or (None, None) when the
    current binary is not an ELF or the function is not found.
    """
    global binary_name
    global binary_ELF

    if Arch.currentIsELF():
        for rela in binary_ELF.jumpRelocationEntries:
            symbolName = rela.symbol.symbolName
            if symbolName == name:
                return (symbolName, rela.r_offset)
    return (None, None)
def parse_keep_regs(string):
    """
    Parses a 'keep registers' string into a list of register uids

    Input: a comma separated list of register names, like "rax,rcx,rdi"
    Output if valid string: (True, list) where list holds the uid of every
        register named, without duplicates
    Output if invalid string: (False, error_message)
    """
    uids = set()
    for regName in string.split(','):
        if regName not in Arch.regNameToNum:
            return (False, "Error. '{}' is not a valid register".format(regName))
        uids.add(Arch.n2r(regName))
    return (True, list(uids))