def verify(self, gadget):
    """
    Check whether the way `gadget` ends is accepted by this constraint.

    self.ret / self.jmp / self.call select which gadget endings are
    acceptable; when all three are False every gadget is accepted.

    Returns a pair (ok, conditions):
      ok         -- True if the gadget is accepted
      conditions -- list of `cond` objects taken from the gadget's IP
                    semantics that must hold for the accepted ending to
                    happen ([] for an unconditional match)
    """
    # No restriction on the gadget ending
    if (self.ret == self.jmp == self.call == False):
        return (True, [])
    elif (self.ret and gadget.retType == RetType.RET):
        return (True, [])
    elif (self.jmp and gadget.retType == RetType.JMP):
        return (True, [])
    elif (self.call and gadget.retType == RetType.CALL):
        return (True, [])

    # The ending type is not known unconditionally: look at every possible
    # final value of IP to see if an accepted ending can happen sometimes.
    # spInc may be None when the SP increment could not be determined;
    # without it the ret test below cannot be evaluated (and None - int
    # would raise TypeError), so skip it in that case.
    if (self.ret and gadget.spInc is not None):
        for p in gadget.getSemantics(Arch.ipNum()):
            if (isinstance(p.expr, MEMExpr)):
                (isInc, inc) = p.expr.addr.isRegIncrement(Arch.spNum())
                # Normal ret if the final value of IP is the value that was in
                # memory just below the final SP
                # (i.e final_IP = MEM[final_SP - size_of_a_register])
                if (isInc and inc == (gadget.spInc - Arch.octets())):
                    return (True, [p.cond])
    # Or a jump ? (IP taken from a register)
    if (self.jmp):
        for p in gadget.getSemantics(Arch.ipNum()):
            if (isinstance(p.expr, SSAExpr)):
                return (True, [p.cond])
    return (False, [])
def init_impossible_REGtoREG(env):
    """
    Pre-compute impossible REG <- REG moves and record them in `env`.

    For every ordered pair (reg1, reg2) of distinct registers where reg2 is
    not the instruction pointer, a REGtoREG search `reg1 <- reg2` is run so
    that `env` can record whether it is impossible. A progress bar is shown
    and the percentage of impossible moves and the elapsed time are
    reported with notify().

    env -- search environment that stores the impossible REGtoREG records
    """
    startTime = datetime.now()
    registers = Arch.registers()
    nb_regs = len(registers)
    # Total number of iterations shown by the progress bar
    total = nb_regs * nb_regs
    i = 0
    impossible_count = 0
    for reg1 in sorted(registers):
        for reg2 in registers:
            i += 1
            charging_bar(total, i, 30)
            if (reg2 == reg1 or reg2 == Arch.ipNum()):
                continue
            _search(QueryType.REGtoREG, reg1, (reg2, 0), env, n=1)
            if (env.checkImpossible_REGtoREG(reg1, reg2, 0)):
                impossible_count += 1
    cTime = datetime.now() - startTime
    # Get how many impossible path we found (guard against < 2 registers)
    nb_pairs = (nb_regs - 1) * nb_regs
    if (nb_pairs > 0):
        impossible_rate = int(100 * (float(impossible_count) / float(nb_pairs)))
    else:
        impossible_rate = 0
    notify('Optimization rate : {}%'.format(impossible_rate))
    notify("Computation time : " + str(cTime))
def init_impossible_REGtoREG(env):
    """
    Pre-compute impossible REG <- REG moves and record them in `env`.

    For every ordered pair (reg1, reg2) of distinct registers where reg2 is
    not the instruction pointer, a REGtoREG search `reg1 <- reg2` is run so
    that `env` can record whether it is impossible. A progress bar is shown
    and the percentage of impossible moves and the elapsed time are
    reported with notify().

    If an exception (or Ctrl-C) interrupts the process, an error is reported
    with fatal() and the function returns normally.

    env -- search environment that stores the impossible REGtoREG records
    """
    global INIT_LMAX, INIT_MAXDEPTH
    global baseAssertion

    try:
        startTime = datetime.now()
        registers = Arch.registers()
        nb_regs = len(registers)
        # Total number of iterations shown by the progress bar
        total = nb_regs * nb_regs
        i = 0
        impossible_count = 0
        for reg1 in sorted(registers):
            for reg2 in registers:
                i += 1
                charging_bar(total, i, 30)
                if (reg2 == reg1 or reg2 == Arch.ipNum()):
                    continue
                _search(QueryType.REGtoREG, reg1, (reg2, 0), env, n=1)
                if (env.checkImpossible_REGtoREG(reg1, reg2, 0)):
                    impossible_count += 1
        cTime = datetime.now() - startTime
        # Get how many impossible path we found (guard against < 2 registers)
        nb_pairs = (nb_regs - 1) * nb_regs
        if (nb_pairs > 0):
            impossible_rate = int(100 * (float(impossible_count) / float(nb_pairs)))
        else:
            impossible_rate = 0
        notify('Optimization rate : {}%'.format(impossible_rate))
        notify("Computation time : " + str(cTime))
    # KeyboardInterrupt is kept so the user can still abort the (long)
    # initialisation with Ctrl-C, but SystemExit/GeneratorExit now propagate.
    except (Exception, KeyboardInterrupt):
        print("\n")
        fatal("Exception caught, stopping Semantic Engine init process...\n")
        fatal("Search time might get very long !\n")
        # NOTE(review): this only rebinds the local name `env`; the
        # environment object held by the caller is not replaced or reset.
        env = SearchEnvironment(INIT_LMAX, Constraint(), baseAssertion, INIT_MAXDEPTH)
def _CSTtoREG_transitivity(reg, cst, env, n=1):
    """
    Perform REG1 <- CST with REG1 <- REG2 <- CST

    reg -- destination register number
    cst -- constant to put in reg
    env -- search environment (lmax, constraint, call history, ...)
    n   -- maximum number of chains to return

    Returns a list of at most n ROPChains ([] when nothing is found or when
    one of the special-case limits below is hit).
    """
    ID = StrategyType.CSTtoREG_TRANSITIVITY

    ## Test for special cases
    # Test lmax
    if (env.getLmax() <= 0):
        return []
    # Limit number of calls to this strategy
    elif (env.nbCalls(ID) >= 99):
        return []
    # Check if the cst is in badBytes
    elif (not env.getConstraint().badBytes.verifyAddress(cst)):
        return []
    # Check if previous call was already CSTtoREG_transitivity
    # Reason: we handle the transitivity with REGtoREG transitivity
    # so no need to do it also recursively with this one ;)
    # (slicing keeps the test safe when the history is empty)
    elif (env.callsHistory()[-1:] == [ID]):
        return []

    # Set env
    env.addCall(ID)
    res = []
    try:
        for inter in Arch.registers():
            if (inter == reg or inter in env.getConstraint().getRegsNotModified()
                    or inter == Arch.ipNum() or inter == Arch.spNum()):
                continue
            # Find reg <- inter
            inter_to_reg = _search(QueryType.REGtoREG, reg, (inter, 0), env, n)
            if (not inter_to_reg):
                continue
            # We found ROPChains s.t reg <- inter
            # Now we want inter <- cst, within the remaining length budget
            # and without clobbering reg
            min_len = min(len(chain) for chain in inter_to_reg)
            env.subLmax(min_len)
            env.addUnusableReg(reg)
            # Integer division: same as Python 2 '/' on ints, stays an int on Python 3
            cst_to_inter = _search(QueryType.CSTtoREG, inter, cst, env,
                                   n // len(inter_to_reg) + 1)
            env.removeUnusableReg(reg)
            env.addLmax(min_len)
            if (not cst_to_inter):
                continue
            for chain2 in inter_to_reg:
                for chain1 in cst_to_inter:
                    if (len(chain1) + len(chain2) <= env.getLmax()):
                        res.append(chain1.addChain(chain2, new=True))
            # Did we get enough chains ?
            if (len(res) >= n):
                break
    finally:
        # Restore env even if a sub-search raises
        env.removeCall(ID)
    return res[:n]
def CSTtoMEM_write(arg1, cst, constraint, assertion, n=1, clmax=LMAX):
    """
    Write a constant to memory through an intermediate register:
        reg <- cst
        mem(arg1) <- reg

    arg1       -- (addr_reg, addr_cst): the destination is mem(addr_reg + addr_cst)
    cst        -- constant to write
    constraint -- search constraint
    assertion  -- search assertion
    n          -- maximum number of chains to return
    clmax      -- maximum chain length

    Returns a list of at most n ROPChains ([] if none found).
    """
    if (clmax <= 0):
        return []

    res = []
    addr_reg = arg1[0]
    addr_cst = arg1[1]
    # The possible memory writes and the padding do not depend on the loop
    # register: fetch them once, the first time they are needed.
    possible_mem_writes = None
    padding = None

    # 1. First strategy (direct)
    # reg <- cst
    # mem(arg1) <- reg
    for reg in range(0, Arch.ssaRegCount):
        if (reg == Arch.ipNum() or reg == Arch.spNum() or reg == addr_reg):
            continue
        # Find reg <- cst without modifying the address register
        # maxdepth 3 or it's too slow
        cst_to_reg_chains = search(QueryType.CSTtoREG, reg, cst,
                                   constraint.add(RegsNotModified([addr_reg])),
                                   assertion, n, clmax - 1, maxdepth=3)
        if (not cst_to_reg_chains):
            continue

        if (possible_mem_writes is None):
            # We get all reg2,cst2 s.t mem(arg1) <- reg2+cst2
            possible_mem_writes = DBPossibleMemWrites(addr_reg, addr_cst,
                                                      constraint, assertion, n=1)
            padding = constraint.getValidPadding(Arch.octets())

        # 1.A. Ideally we look for reg2=reg and cst2=0 (direct_writes)
        writes_for_reg = possible_mem_writes.get(reg)
        direct_writes = writes_for_reg.get(0, []) if writes_for_reg else []

        for write_gadget in direct_writes:
            for cst_to_reg_chain in cst_to_reg_chains:
                # Pad the write gadget up to its stack increment
                write_chain = ROPChain([write_gadget])
                for i in range(0, write_gadget.spInc - Arch.octets(), Arch.octets()):
                    write_chain.addPadding(padding)
                full_chain = cst_to_reg_chain.addChain(write_chain, new=True)
                if (len(full_chain) <= clmax):
                    res.append(full_chain)
                    if (len(res) >= n):
                        return res
        # 1.B. TODO: adjust cst2 when no direct write exists
    return res
def _MEMtoREG_transitivity(reg, arg2, env, n=1):
    """
    Perform reg <- inter <- mem(arg2)

    reg  -- destination register number
    arg2 -- memory source operand passed to the MEMtoREG search
    env  -- search environment (lmax, constraint, call history, ...)
    n    -- maximum number of chains to return

    Returns a list of at most n ROPChains, or a FailRecord when nothing is
    found (FailRecord(lmax=True) when env has no length budget left).
    """
    ID = StrategyType.MEMtoREG_TRANSITIVITY

    ## Test for special cases
    # Test lmax
    if (env.getLmax() <= 0):
        return FailRecord(lmax=True)
    # Limit number of calls to this strategy
    elif (env.nbCalls(ID) >= 99):
        return FailRecord()
    # Check if previous call was already MEMtoREG_transitivity
    # Reason: we handle the transitivity with REGtoREG transitivity
    # so no need to do it also recursively with this one ;)
    # (slicing keeps the test safe when the history is empty)
    elif (env.callsHistory()[-1:] == [ID]):
        return FailRecord()

    # Set env
    env.addCall(ID)
    res = []
    res_fail = FailRecord()
    try:
        for inter in Arch.registers():
            if (inter == reg or inter in env.getConstraint().getRegsNotModified()
                    or inter == Arch.ipNum() or inter == Arch.spNum()):
                continue
            # Find reg <- inter
            inter_to_reg = _search(QueryType.REGtoREG, reg, (inter, 0), env, n)
            if (inter_to_reg):
                min_len = min(len(chain) for chain in inter_to_reg)
                # Try to find inter <- mem(arg2) within the remaining budget,
                # without clobbering reg
                env.subLmax(min_len)
                env.addUnusableReg(reg)
                arg2_to_inter = _search(QueryType.MEMtoREG, inter, arg2, env, n)
                env.removeUnusableReg(reg)
                env.addLmax(min_len)
                if (not arg2_to_inter):
                    res_fail.merge(arg2_to_inter)
                    continue
                res += [chain1.addChain(chain2, new=True)
                        for chain1 in arg2_to_inter
                        for chain2 in inter_to_reg
                        if len(chain1) + len(chain2) <= env.getLmax()]
            else:
                res_fail.merge(inter_to_reg)
            # Did we get enough chains ?
            if (len(res) >= n):
                break
    finally:
        # Restore env even if a sub-search raises
        env.removeCall(ID)
    return res[:n] if res else res_fail
def _CSTtoREG_transitivity(reg, cst, constraint, assertion, n=1, clmax=LMAX, comment=None, maxdepth=4):
    """
    Perform REG1 <- CST with REG1 <- REG2 <- CST

    reg        -- destination register number
    cst        -- constant to put in reg
    constraint -- search constraint
    assertion  -- search assertion
    n          -- maximum number of chains to return
    clmax      -- maximum chain length
    comment    -- comment forwarded to _CSTtoREG_pop for the constant
    maxdepth   -- maxdepth of the SearchRecord used for the REGtoREG search

    Returns a list of at most n ROPChains ([] if none found).
    """
    # Test clmax
    if (clmax <= 0):
        return []

    res = []
    for inter in range(0, Arch.ssaRegCount):
        if (inter == reg or inter in constraint.getRegsNotModified()
                or inter == Arch.ipNum() or inter == Arch.spNum()):
            continue
        # Find reg <- inter, with reg marked unusable in the REGtoREG record
        REGtoREG_record = SearchRecord(maxdepth=maxdepth)
        REGtoREG_record.unusable_REGtoREG.append(reg)
        inter_to_reg = search(QueryType.REGtoREG, reg, (inter, 0), constraint,
                              assertion, n, clmax, record=REGtoREG_record)
        if (not inter_to_reg):
            continue

        # We found ROPChains s.t reg <- inter, now we want inter <- cst.
        # Integer division: same as Python 2 '/' on ints, stays an int on Python 3
        n_inter = n // len(inter_to_reg) + 1
        # First with basic gadgets ...
        cst_to_inter = _basic(QueryType.CSTtoREG, inter, cst, constraint,
                              assertion, n_inter, clmax - 1)
        res += [chain1.addChain(chain2, new=True)
                for chain2 in inter_to_reg
                for chain1 in cst_to_inter
                if len(chain1) + len(chain2) <= clmax]
        # ... then by popping the constant from the stack if needed
        if (len(res) < n):
            cst_to_inter = _CSTtoREG_pop(inter, cst, constraint, assertion,
                                         n_inter, clmax - 1, comment)
            res += [chain1.addChain(chain2, new=True)
                    for chain2 in inter_to_reg
                    for chain1 in cst_to_inter
                    if len(chain1) + len(chain2) <= clmax]
        # Did we get enough chains ?
        if (len(res) >= n):
            return res[:n]
    # Return what we got
    return res
def _REGtoMEM_transitivity(arg1, arg2, env, n=1):
    """
    Write a register to memory through an intermediate register:
        inter <- arg2
        mem(arg1) <- inter

    arg1 -- memory destination operand passed to the REGtoMEM search
    arg2 -- (reg, cst) source operand
    env  -- search environment (lmax, constraint, call history, ...)
    n    -- maximum number of chains to return

    Returns a list of at most n ROPChains, or a FailRecord when nothing is
    found (FailRecord(lmax=True) when env has no length budget left).
    """
    ID = StrategyType.REGtoMEM_TRANSITIVITY

    ## Test for special cases
    # Test lmax
    if (env.getLmax() <= 0):
        return FailRecord(lmax=True)
    # Limit number of calls to this strategy
    elif (env.nbCalls(ID) >= 99):
        return FailRecord()
    # Check if previous call was already REGtoMEM_transitivity
    # Reason: we handle the transitivity with REGtoREG transitivity
    # so no need to do it also recursively with this one ;)
    # (slicing keeps the test safe when the history is empty)
    elif (env.callsHistory()[-1:] == [ID]):
        return FailRecord()

    # Set env
    env.addCall(ID)
    res = []
    res_fail = FailRecord()
    try:
        for inter in Arch.registers():
            if (inter == arg2[0] or inter in env.getConstraint().getRegsNotModified()
                    or inter == Arch.ipNum() or inter == Arch.spNum()):
                continue
            # Find inter <- arg2
            arg2_to_inter = _search(QueryType.REGtoREG, inter, (arg2[0], arg2[1]), env, n)
            if (arg2_to_inter):
                len_min = min(len(chain) for chain in arg2_to_inter)
                # Try to find mem(arg1) <- inter within the remaining budget
                env.subLmax(len_min)
                env.addUnusableReg(arg2[0])
                inter_to_mem = _search(QueryType.REGtoMEM, arg1, (inter, 0), env)
                env.removeUnusableReg(arg2[0])
                env.addLmax(len_min)
                if (not inter_to_mem):
                    res_fail.merge(inter_to_mem)
                    continue
                res += [chain1.addChain(chain2, new=True)
                        for chain1 in arg2_to_inter
                        for chain2 in inter_to_mem
                        if len(chain1) + len(chain2) <= env.getLmax()]
                if (len(res) >= n):
                    break
            else:
                res_fail.merge(arg2_to_inter)
    finally:
        # Restore env even if a sub-search raises
        env.removeCall(ID)
    # Never return more than the n chains that were asked for
    return res[:n] if res else res_fail
def _CSTtoREG_pop(reg, cst, constraint, assertion, n=1, clmax=LMAX, comment=None):
    """
    Build chains that load `cst` into register `reg` by popping it off the stack.

    reg        -- destination register number
    cst        -- constant to load (placed in the padding of the pop gadget)
    constraint -- search constraint (bad bytes, padding, chainability)
    assertion  -- search assertion
    n          -- maximum number of chains to return
    clmax      -- maximum chain length
    comment    -- comment attached to the constant in the chain; a default
                  "Constant: 0x..." comment is used when it is empty

    Returns a list of at most n ROPChains ([] if none can be built).
    """
    # Nothing can be built without a length budget or a request
    if clmax <= 0 or n < 1:
        return []
    # The constant itself must not contain bad bytes
    if not constraint.badBytes.verifyAddress(cst):
        return []
    if not comment:
        comment = "Constant: " + string_bold("0x{:x}".format(cst))

    # Popping into IP does not require the gadget to be chainable
    if reg == Arch.ipNum():
        pop_constraint = constraint.remove([CstrTypeID.CHAINABLE])
    else:
        pop_constraint = constraint.add(Chainable(ret=True))
    offsets_db = DBPossiblePopOffsets(reg, pop_constraint, assertion)

    word = Arch.octets()
    found = []
    for offset in sorted(o for o in offsets_db.keys() if o >= 0):
        # Offsets are sorted: once one is out of reach, all following are too
        if offset > clmax * word:
            break
        # Keep gadgets whose stack frame contains `offset` and whose padding
        # fits in clmax
        candidates = [g for g in offsets_db[offset]
                      if g.spInc >= word
                      and g.spInc - word > offset
                      and (g.spInc / word - 1) <= clmax]
        padding = constraint.getValidPadding(word)
        for gadget in candidates:
            chain = ROPChain([gadget])
            # Fill the gadget's stack frame, the constant going at `offset`
            for slot in range(0, gadget.spInc - word, word):
                if slot == offset:
                    chain.addPadding(cst, comment)
                else:
                    chain.addPadding(padding)
            if len(chain) <= clmax:
                found.append(chain)
                if len(found) >= n:
                    return found
    return found
def build_call(funcName, funcArgs, constraint, assertion):
    """
    Build a ROP-Chain calling funcName with stack-passed integer arguments.

    The chain is: function address, address of a gadget that skips the
    arguments (used as fake return address), then the arguments.

    funcName   -- name of the function to call
    funcArgs   -- list of integer arguments
    constraint -- search constraint (bad bytes, ...)
    assertion  -- search assertion

    Returns a ROPChain on success, or an error message (str) on failure.
    """
    # Find the address of the function
    (funcName2, funcAddr) = getFunctionAddress(funcName)
    if (funcName2 is None):
        return "Couldn't find function '{}' in the binary".format(funcName)
    # Check if bad bytes in function address
    if (not constraint.badBytes.verifyAddress(funcAddr)):
        return "'{}' address ({}) contains bad bytes".format(
            funcName2,
            string_special('0x' + format(funcAddr, '0' + str(Arch.octets() * 2) + 'x')))

    # Find a gadget for the fake return address: IP <- mem(SP + offset),
    # where offset skips the arguments. The stack word size depends on the
    # architecture (it was hard-coded as 8, which is wrong on 32-bit).
    word = Arch.octets()
    offset = len(funcArgs) * word - word  # Because we do +word at the beginning of the loop
    skip_args_chains = []
    i = 4  # Try 4 offsets maximum
    while (i > 0 and (not skip_args_chains)):
        offset += word
        skip_args_chains = search(QueryType.MEMtoREG, Arch.ipNum(),
                                  (Arch.spNum(), offset), constraint, assertion, n=1)
        i -= 1
    if (not skip_args_chains):
        return "Couldn't build ROP-Chain"
    skip_args_chain = skip_args_chains[0]

    # Build the ropchain with the arguments
    args_chain = ROPChain()
    arg_n = len(funcArgs)
    for arg in reversed(funcArgs):
        if (isinstance(arg, int)):
            args_chain.addPadding(arg, comment="Arg{}: {}".format(
                arg_n, string_ropg(hex(arg))))
            arg_n -= 1
        else:
            return "Type of argument '{}' not supported yet :'(".format(arg)

    # Build call chain (function address + fake return address)
    call_chain = ROPChain()
    call_chain.addPadding(funcAddr, comment=string_ropg(funcName2))
    skip_args_addr = int(
        validAddrStr(skip_args_chain.chain[0], constraint.getBadBytes(), Arch.bits()), 16)
    call_chain.addPadding(skip_args_addr, comment="Address of: " +
                          string_bold(str(skip_args_chain.chain[0])))
    return call_chain.addChain(args_chain)
def build_call_linux86(funcName, funcArgs, constraint, assertion, clmax=None, optimizeLen=False):
    """
    Build a ROP-Chain calling funcName with stack-passed integer arguments.

    The chain is: function address, then (only when there are arguments)
    the address of a gadget that skips them, used as fake return address,
    then the arguments.

    funcName    -- name of the function to call
    funcArgs    -- list of integer arguments
    constraint  -- search constraint (bad bytes, ...)
    assertion   -- search assertion
    clmax       -- maximum chain length; None disables the length check
    optimizeLen -- forwarded to the gadget search

    Returns a ROPChain on success, or an error message (str) on failure.
    """
    # Find the address of the function
    (funcName2, funcAddr) = getFunctionAddress(funcName)
    if (funcName2 is None):
        return "Couldn't find function '{}' in the binary".format(funcName)
    # Check if bad bytes in function address
    if (not constraint.badBytes.verifyAddress(funcAddr)):
        return "'{}' address ({}) contains bad bytes".format(
            funcName2,
            string_special('0x' + format(funcAddr, '0' + str(Arch.octets() * 2) + 'x')))
    # Check if lmax too small: function address + arguments + one skip
    # gadget when there are arguments. With the default clmax=None the check
    # used to compare against None (always failing on Python 2, TypeError on
    # Python 3); now None means "no limit".
    needed = 1 + len(funcArgs) + (1 if funcArgs else 0)
    if (clmax is not None and needed > clmax):
        return "Not enough bytes to call function '{}'".format(funcName)

    # Find a gadget for the fake return address
    if (funcArgs):
        offset = (len(funcArgs) - 1) * Arch.octets()  # Because we do +octets() at the beginning of the loop
        skip_args_chains = []
        i = 4  # Try 4 more maximum
        while (i > 0 and (not skip_args_chains)):
            offset += Arch.octets()
            skip_args_chains = search(QueryType.MEMtoREG, Arch.ipNum(),
                                      (Arch.spNum(), offset), constraint, assertion,
                                      n=1, optimizeLen=optimizeLen)
            i -= 1
        if (not skip_args_chains):
            return "Couldn't build ROP-Chain"
        skip_args_chain = skip_args_chains[0]
    else:
        # No arguments
        skip_args_chain = None

    # Build the ropchain with the arguments
    args_chain = ROPChain()
    arg_n = len(funcArgs)
    for arg in reversed(funcArgs):
        if (isinstance(arg, int)):
            args_chain.addPadding(arg, comment="Arg{}: {}".format(arg_n, string_ropg(hex(arg))))
            arg_n -= 1
        else:
            return "Type of argument '{}' not supported yet :'(".format(arg)

    # Build call chain (function address + fake return address)
    call_chain = ROPChain()
    call_chain.addPadding(funcAddr, comment=string_ropg(funcName2))
    if (funcArgs):
        skip_args_addr = int(
            validAddrStr(skip_args_chain.chain[0], constraint.getBadBytes(), Arch.bits()), 16)
        call_chain.addPadding(skip_args_addr, comment="Address of: " +
                              string_bold(str(skip_args_chain.chain[0])))
    return call_chain.addChain(args_chain)
def _REGtoREG_transitivity(arg1, arg2, constraint, assertion, record, n=1, clmax=LMAX):
    """
    Perform REG1 <- REG2+CST with REG1 <- REG3 <- REG2+CST

    arg1       -- destination register number (REG1)
    arg2       -- (REG2, CST) source operand
    constraint -- search constraint
    assertion  -- search assertion
    record     -- SearchRecord; its unusable_REGtoREG list is temporarily
                  extended during the sub-searches and always restored
    n          -- maximum number of chains to return
    clmax      -- maximum chain length

    Returns a list of at most n ROPChains ([] if none found).
    """
    # Test clmax
    if (clmax <= 0):
        return []
    # If reg1 <- reg1 + 0, return
    if (arg1 == arg2[0] and arg2[1] == 0):
        return []

    res = []
    for inter_reg in range(0, Arch.ssaRegCount):
        if (inter_reg == arg1 or (inter_reg == arg2[0] and arg2[1] == 0)
                or (inter_reg in record.unusable_REGtoREG)
                or inter_reg == Arch.ipNum() or inter_reg == Arch.spNum()):
            continue

        # Find reg1 <- inter_reg without using arg2
        record.unusable_REGtoREG.append(arg2[0])
        try:
            inter_to_arg1_list = search(QueryType.REGtoREG, arg1, (inter_reg, 0),
                                        constraint, assertion, n, clmax=clmax - 1,
                                        record=record)
        finally:
            record.unusable_REGtoREG.remove(arg2[0])
        if (not inter_to_arg1_list):
            continue

        # Find inter_reg <- arg2 without using arg1. The record must be
        # restored on every exit path, including the early return below
        # (previously arg1 stayed marked unusable after returning).
        n2 = max(1, n // len(inter_to_arg1_list))
        record.unusable_REGtoREG.append(arg1)
        try:
            for arg2_to_inter in search(QueryType.REGtoREG, inter_reg, arg2,
                                        constraint, assertion, n2, clmax=clmax - 1,
                                        record=record):
                for inter_to_arg1 in inter_to_arg1_list:
                    if (len(inter_to_arg1) + len(arg2_to_inter) <= clmax):
                        res.append(arg2_to_inter.addChain(inter_to_arg1, new=True))
                        if (len(res) >= n):
                            return res
        finally:
            record.unusable_REGtoREG.remove(arg1)
    return res
def MEMtoREG_transitivity(reg, arg2, constraint, assertion, n=1, clmax=LMAX):
    """
    Perform reg <- inter <- mem(arg2)

    reg        -- destination register number
    arg2       -- memory source operand passed to the MEMtoREG search
    constraint -- search constraint
    assertion  -- search assertion
    n          -- maximum number of chains to return
    clmax      -- maximum chain length

    Returns a list of at most n ROPChains ([] if none found).
    """
    if (clmax <= 0):
        return []

    res = []
    for inter in range(0, Arch.ssaRegCount):
        if (inter == reg or inter in constraint.getRegsNotModified()
                or inter == Arch.ipNum() or inter == Arch.spNum()):
            continue
        # Find reg <- inter, with reg marked unusable in the REGtoREG record
        REGtoREG_record = SearchRecord(maxdepth=4)
        REGtoREG_record.unusable_REGtoREG.append(reg)
        inter_to_reg = search(QueryType.REGtoREG, reg, (inter, 0), constraint,
                              assertion, n, clmax - 1, record=REGtoREG_record)
        if (not inter_to_reg):
            continue
        len_min = min(len(chain) for chain in inter_to_reg)
        # Find inter <- mem(arg2) with basic chainable gadgets, in the
        # length left by the shortest reg <- inter chain
        arg2_to_inter = _basic(QueryType.MEMtoREG, inter, arg2,
                               constraint.add(Chainable(ret=True)), assertion,
                               n, clmax - len_min)
        res += [chain1.addChain(chain2, new=True)
                for chain1 in arg2_to_inter
                for chain2 in inter_to_reg
                if len(chain1) + len(chain2) <= clmax]
        # TODO: second strategy (read reg) not implemented yet
        # Did we get enough chains ? (never return more than n)
        if (len(res) >= n):
            return res[:n]
    # Return the best we got
    return res
def __init__(self, addr_list, raw):
    """
    Build a gadget from its raw bytes.

    addr_list = list of addresses of the gadget (if duplicate gadgets)
    raw = raw string of gadget asm

    'int 0x80' and 'syscall' (Intel only) become INT80/SYSCALL gadgets with
    empty semantics. Any other input is translated to REIL and its graph
    and semantics are computed (GadgetType.REGULAR).

    Raises GadgetException if the REIL translation or the graph
    construction fails.
    """
    # Check the type of the gadget
    # Check for 'int 0x80' gadgets
    if (raw == '\xcd\x80' and Arch.currentIsIntel()):
        self.type = GadgetType.INT80
        self.asmStr = 'int 0x80'
        self.hexStr = '\\xcd\\x80'
        self.addrList = addr_list
        self.nbInstr = self.nbInstrREIL = 1
        self.semantics = Semantics()
        # Special gadgets have no computed SP increment / return type; set
        # the attributes anyway so code reading them (e.g. retType, spInc)
        # on any gadget does not raise AttributeError.
        self.spInc = None
        self.retType = RetType.UNKNOWN
        self.retValue = None
        return
    # Check for 'syscall' gadgets
    elif (raw == '\x0f\x05' and Arch.currentIsIntel()):
        self.type = GadgetType.SYSCALL
        self.asmStr = 'syscall'
        self.hexStr = '\\x0f\\x05'
        self.addrList = addr_list
        self.nbInstr = self.nbInstrREIL = 1
        self.semantics = Semantics()
        # Same neutral values as for 'int 0x80' gadgets
        self.spInc = None
        self.retType = RetType.UNKNOWN
        self.retValue = None
        return

    # Translate raw assembly into REIL
    # Then compute the Graph and its semantics
    try:
        (irsb, ins) = Arch.currentArch.asmToREIL(raw)
    except Arch.ArchException as e:
        raise GadgetException(str(e))
    try:
        self.graph = REILtoGraph(irsb)
        self.semantics = self.graph.getSemantics()
    except GraphException as e:
        raise GadgetException("(In {}) - ".format('; '.join(str(i) for i in ins)) + str(e))

    self.type = GadgetType.REGULAR
    # Possible addresses
    self.addrList = addr_list
    # String representations
    self.asmStr = '; '.join(str(i) for i in ins)
    self.hexStr = '\\x' + '\\x'.join("{:02x}".format(ord(c)) for c in raw)
    # Length of the gadget
    self.nbInstr = len(ins)
    self.nbInstrREIL = len(irsb)

    # List of modified registers
    # And of memory-read accesses
    self._modifiedRegs = []
    self._memoryReads = []
    for reg_num in list(set([reg.num for reg in self.semantics.registers.keys()])):
        # A register with an empty semantics is logged and counted as modified
        if (not self.getSemantics(reg_num)):
            log("Gadget ({}) : empty semantics for {}"
                .format(self.asmStr, Arch.r2n(reg_num)))
            self._modifiedRegs.append(reg_num)
            continue
        # Modified if its first semantic value is not its initial value
        if (SSAExpr(reg_num, 0) != self.getSemantics(reg_num)[0].expr):
            self._modifiedRegs.append(reg_num)
        # Get memory reads
        for pair in self.getSemantics(reg_num):
            self._memoryReads += [m[0] for m in pair.expr.getMemAcc()]
    self._modifiedRegs = list(set(self._modifiedRegs))

    # SP Increment: 0 if SP is never modified, the constant when the final SP
    # has a single semantics of the form SP + cst, None otherwise
    sp_num = Arch.spNum()
    if (not sp_num in self.graph.lastMod):
        self.spInc = 0
    else:
        sp = SSAReg(sp_num, self.graph.lastMod[sp_num])
        sp_semantics = self.semantics.get(sp)
        self.spInc = None
        if (len(sp_semantics) == 1):
            (isInc, inc) = sp_semantics[0].expr.isRegIncrement(sp_num)
            if (isInc):
                self.spInc = inc

    # Return type (only decided from unconditional IP semantics)
    self.retType = RetType.UNKNOWN
    self.retValue = None
    ip_num = Arch.ipNum()
    ip = SSAReg(ip_num, self.graph.lastMod[ip_num])
    if (self.spInc is not None):
        for p in self.semantics.get(ip):
            if (p.cond.isTrue()):
                if (isinstance(p.expr, MEMExpr)):
                    addr = p.expr.addr
                    (isInc, inc) = addr.isRegIncrement(sp_num)
                    # Normal ret if the final value of the IP is value that was in memory before the last modification of SP ( i.e final_IP = MEM[final_sp - size_of_a_register )
                    if (isInc and self.spInc and inc == (self.spInc - (Arch.currentArch.octets))):
                        self.retType = RetType.RET
                        self.retValue = p.expr
                elif (isinstance(p.expr, SSAExpr)):
                    self.retValue = p.expr
                    # Try to detect gadgets ending by 'call'
                    if (ins[-1]._mnemonic[:4] == "call"):
                        self.retType = RetType.CALL
                    else:
                        self.retType = RetType.JMP
def build_dshell(shellcode, constraint, assertion, address, limit, lmax):
    """
    Build a chain that copies `shellcode` to memory, makes it executable
    with mprotect() and jumps to it.

    shellcode  -- shellcode bytes (string)
    constraint -- search constraint
    assertion  -- search assertion
    address    -- where to copy the shellcode; the .bss address is used
                  when it is falsy
    limit      -- upper address bound for the copy; defaults to
                  address + Arch.minPageSize()
    lmax       -- maximum total chain length

    Returns a PwnChain() instance or None
    """
    # Build exploit
    #################
    res = PwnChain()

    # Find address for the payload
    if (not address):
        # Get the .bss address
        # TODO
        notify("Getting delivery address for shellcode")
        address = getSectionAddress('.bss')
        addr_str = ".bss"
        if (not address):
            verbose("Couldn't find .bss address")
            # Documented return value on failure is None (was [])
            return None
    else:
        addr_str = hex(address)

    if (not limit):
        limit = address + Arch.minPageSize()

    # Deliver shellcode
    notify("Building chain to copy shellcode in memory")
    verbose("{}/{} bytes available".format(lmax * Arch.octets(), lmax * Arch.octets()))
    (shellcode_address, STRtoMEM_chain) = STRtoMEM(shellcode, address, constraint, assertion,
                                                  limit=limit, lmax=lmax, addr_str=addr_str,
                                                  hex_info=True, optimizeLen=True)
    # Check for failure before using shellcode_address (hex() would fail on
    # a missing address)
    if (not STRtoMEM_chain):
        verbose("Could not copy shellcode into memory")
        return None
    address = shellcode_address
    addr_str = hex(address)

    # Building mprotect
    notify("Building mprotect() chain")
    # Getting page to make executable
    # Arg of mprotect MUST be a valid multiple of page size
    over_page_size = address % Arch.minPageSize()
    page_address = address - over_page_size
    length = len(shellcode) + 1 + over_page_size
    flag = 7
    lmax2 = lmax - len(STRtoMEM_chain)
    verbose("{}/{} bytes available".format(lmax2 * Arch.octets(), lmax * Arch.octets()))
    if (lmax2 <= 0):
        return None
    if (Arch.currentArch == Arch.ArchX86):
        mprotect_chain = build_mprotect32(page_address, length, flag,
                                          constraint.add(Chainable(ret=True)), assertion,
                                          clmax=lmax2 - 2, optimizeLen=True)
    elif (Arch.currentArch == Arch.ArchX64):
        mprotect_chain = build_mprotect64(page_address, length, flag,
                                          constraint.add(Chainable(ret=True)), assertion,
                                          clmax=lmax2 - 2, optimizeLen=True)
    else:
        verbose("mprotect call not supported for architecture {}".format(
            Arch.currentArch.name))
        return None
    if (not mprotect_chain):
        return None
    verbose("Done")

    # Jump to shellcode
    notify("Searching chain to jump to shellcode")
    verbose("{}/{} bytes available".format(
        (lmax2 - len(mprotect_chain)) * Arch.octets(), lmax * Arch.octets()))
    jmp_shellcode_chains = search(QueryType.CSTtoREG, Arch.ipNum(), address, constraint,
                                  assertion,
                                  clmax=lmax - len(STRtoMEM_chain) - len(mprotect_chain),
                                  optimizeLen=True)
    if (not jmp_shellcode_chains):
        verbose("Couldn't find a jump to the shellcode")
        return None
    verbose("Done")
    notify("Done")

    # Build PwnChain res and return
    res.add(mprotect_chain, "Call mprotect({},{},{})".format(hex(page_address), length, flag))
    res.add(STRtoMEM_chain, "Copy shellcode to {}".format(addr_str))
    res.add(jmp_shellcode_chains[0], "Jump to shellcode (address {})".format(addr_str))
    return res
def _adjust_ret(qtype, arg1, arg2, env, n):
    """
    Search with basic but adjust the bad returns they have

    Gadgets ending with a jmp/call to a register are found with _basic, then
    made chainable by loading into that register the address of a gadget
    (IP <- mem(SP + offset)) that resumes the chain.

    qtype -- query type forwarded to _basic
    arg1  -- query destination (register number or (reg, cst))
    arg2  -- query source (constant or (reg, cst))
    env   -- search environment (lmax, constraint, call history, ...)
    n     -- maximum number of chains to return

    Returns a list of ROPChains, or a FailRecord when nothing is found
    (FailRecord(lmax=True) when env has no length budget left).
    """
    global LMAX
    ID = StrategyType.ADJUST_RET

    ## Test for special cases
    # Test lmax
    if (env.getLmax() <= 0):
        return FailRecord(lmax=True)
    # Limit number of calls to this strategy
    elif (env.nbCalls(ID) >= 2):
        return FailRecord()
    # Test for ip
    # Reason: can not adjust ip if ip is the
    # target of the query :/
    elif (arg1 == Arch.ipNum()):
        return FailRecord()

    # Set env
    env.addCall(ID)
    saved_adjust_ret = env.getImpossible_adjust_ret().copy()
    res = []
    res_fail = FailRecord()
    try:
        padding = env.getConstraint().getValidPadding(Arch.octets())
        # Get possible gadgets (ending with jmp/call)
        constraint = env.getConstraint()
        env.setConstraint(constraint.add(Chainable(jmp=True, call=True)))
        possible = _basic(qtype, arg1, arg2, env, 10 * n)
        env.setConstraint(constraint)
        if (not possible):
            res_fail.merge(possible)
            possible = []

        # Registers of the query that the adjustment must not modify.
        # Operands are either a plain value or a (reg, cst) tuple; testing
        # for a sequence also handles Python 2 long constants.
        arg1_reg = arg1[0] if isinstance(arg1, (tuple, list)) else arg1
        arg2_reg = arg2[0] if isinstance(arg2, (tuple, list)) else arg2

        # Try to adjust them
        for chain in possible:
            g = chain.chain[0]
            ret_reg = g.retValue.reg.num
            # Check if we already know that ret_reg can't be adjusted
            if (env.checkImpossible_adjust_ret(ret_reg)):
                continue
            # Check if ret_reg not modified within the gadget
            elif (ret_reg in g.modifiedRegs()):
                continue
            # Check if stack is preserved
            elif (g.spInc is None):
                continue

            # Find adjustment
            if (g.spInc < 0):
                offset = -1 * g.spInc
                padding_length = 0
            else:
                # Number of padding words; '//' keeps it an int on Python 3
                padding_length = g.spInc // Arch.octets()
                if (g.retType == RetType.JMP):
                    offset = 0
                else:
                    offset = Arch.octets()

            # Get adjustment gadgets (with the global LMAX budget)
            env.setConstraint(constraint.add(RegsNotModified([arg1_reg])))
            saved_lmax = env.getLmax()
            env.setLmax(LMAX)
            adjust_gadgets = _search(QueryType.MEMtoREG, Arch.ipNum(),
                                     (Arch.spNum(), offset), env, n=1)
            env.setConstraint(constraint)
            env.setLmax(saved_lmax)
            if (not adjust_gadgets):
                res_fail.merge(adjust_gadgets)
                continue
            adjust_addr = int(validAddrStr(adjust_gadgets[0].chain[0],
                                           constraint.getBadBytes(), Arch.bits()), 16)

            # Find gadgets to put the gadget address in the jmp/call register
            env.setConstraint(constraint.add(RegsNotModified([arg2_reg])))
            env.subLmax(1 + padding_length)
            env.pushComment(StrategyType.CSTtoREG_POP,
                            "Address of " + string_bold(str(adjust_gadgets[0].chain[0])))
            adjust = _search(QueryType.CSTtoREG, ret_reg, adjust_addr, env, n=1)
            env.popComment(StrategyType.CSTtoREG_POP)
            env.addLmax(1 + padding_length)
            env.setConstraint(constraint)
            if (adjust):
                res.append(adjust[0].addGadget(g).addPadding(padding, n=padding_length))
                if (len(res) >= n):
                    break
            else:
                # Update the search record to say that reg_ret cannot be adjusted
                env.addImpossible_adjust_ret(ret_reg)
                res_fail.merge(adjust)
    finally:
        # Restore env even if a sub-search raises
        env.impossible_adjust_ret = saved_adjust_ret
        env.removeCall(ID)
    return res if res else res_fail
def _CSTtoMEM_write(arg1, cst, env, n=1):
    """
    Write a constant to memory through an intermediate register:
        reg <- cst
        mem(arg1) <- reg

    arg1 -- (addr_reg, addr_cst): the destination is mem(addr_reg + addr_cst)
    cst  -- constant to write
    env  -- search environment (lmax, constraint, call history, ...)
    n    -- maximum number of chains to return

    Returns a list of at most n ROPChains, or a FailRecord when nothing is
    found (FailRecord(lmax=True) when env has no length budget left).
    """
    ID = StrategyType.CSTtoMEM_WRITE

    ## Test for special cases
    # Test lmax
    if (env.getLmax() <= 0):
        return FailRecord(lmax=True)
    # Limit number of calls to this strategy
    elif (env.nbCalls(ID) >= 99):
        return FailRecord()

    # Set env
    env.addCall(ID)
    res = []
    # Was a bare `res_fail` expression, which raised NameError on every
    # failure path
    res_fail = FailRecord()
    addr_reg = arg1[0]
    addr_cst = arg1[1]
    constraint = env.getConstraint()
    # Invariant across registers: fetched once, the first time it is needed
    possible_mem_writes = None
    padding = None
    try:
        # 1. First strategy (direct)
        # reg <- cst
        # mem(arg1) <- reg
        for reg in Arch.registers():
            if (reg == Arch.ipNum() or reg == Arch.spNum() or reg == addr_reg):
                continue
            # Find reg <- cst without modifying the address register, keeping
            # one slot for the write gadget
            env.setConstraint(constraint.add(RegsNotModified([addr_reg])))
            env.subLmax(1)
            cst_to_reg_chains = _search(QueryType.CSTtoREG, reg, cst, env, n)
            env.addLmax(1)
            env.setConstraint(constraint)
            if (not cst_to_reg_chains):
                res_fail.merge(cst_to_reg_chains)
                continue

            if (possible_mem_writes is None):
                # We get all reg2,cst2 s.t mem(arg1) <- reg2+cst2
                possible_mem_writes = DBPossibleMemWrites(addr_reg, addr_cst, constraint,
                                                          env.getAssertion(), n=1)
                padding = constraint.getValidPadding(Arch.octets())

            # 1.A. Ideally we look for reg2=reg and cst2=0 (direct_writes)
            writes_for_reg = possible_mem_writes.get(reg)
            direct_writes = writes_for_reg.get(0, []) if writes_for_reg else []

            for write_gadget in direct_writes:
                for cst_to_reg_chain in cst_to_reg_chains:
                    # Pad the write gadget up to its stack increment
                    write_chain = ROPChain([write_gadget])
                    for i in range(0, write_gadget.spInc - Arch.octets(), Arch.octets()):
                        write_chain.addPadding(padding)
                    full_chain = cst_to_reg_chain.addChain(write_chain, new=True)
                    if (len(full_chain) <= env.getLmax()):
                        res.append(full_chain)
                    if (len(res) >= n):
                        break
                if (len(res) >= n):
                    break
            if (len(res) >= n):
                break
            # 1.B. TODO: otherwise try to adjust cst2
        # 2d Strategy (indirect) TODO, in another function:
        # reg <- arg2 - cst
        # mem(arg1) <- reg + cst
    finally:
        # Restore env even if a sub-search raises
        env.removeCall(ID)
    return res if res else res_fail
def STRtoMEM_memcpy(string, addr, constraint, assertion, lmax=STR_TO_MEM_LMAX, addr_str=None, hex_info=False):
    """
    MEMCPY STRATEGY
    Copy the string using memcpy function

    The string is split into substrings already present in the binary
    (findBytes), and each one is copied to its place with a memcpy call
    followed by a pop-pop-pop-ret gadget that discards the three arguments.

    string     -- bytes to write (string)
    addr       -- destination address
    constraint -- search constraint (bad bytes, ...)
    assertion  -- search assertion
    lmax       -- maximum chain length
    addr_str   -- text used for addr in comments (hex of addr by default)
    hex_info   -- show substrings as \\xNN escapes in comments

    Returns a ROPChain, or None when the chain cannot be built.
    """
    if not addr_str:
        addr_str = "0x" + format(addr, '0' + str(Arch.octets() * 2) + 'x')

    # Locate memcpy in the binary
    (func_name, func_addr) = getFunctionAddress('memcpy')
    if not func_addr:
        verbose('Could not find memcpy function')
        return None
    if not constraint.badBytes.verifyAddress(func_addr):
        verbose("memcpy address ({}) contains bad bytes".format(hex(func_addr)))
        return None

    # Decompose the string into substrings found in the binary
    pieces = findBytes(string, badBytes=constraint.getBadBytes())
    if not pieces:
        return None
    # Every memcpy call uses 5 words: memcpy, pppr gadget and 3 arguments
    words_needed = len(pieces) * 5
    if words_needed > lmax:
        verbose(
            "Memcpy: ROP-Chain too long (length: {}, available bytes: {}) ".format(
                words_needed * Arch.octets(), lmax * Arch.octets()))
        return None

    # A gadget doing IP <- mem(SP + 3 words) with SP += 4 words, i.e.
    # pop-pop-pop-ret, to skip the three memcpy arguments
    pppr_chains = _basic(QueryType.MEMtoREG, Arch.ipNum(),
                         [Arch.spNum(), 3 * Arch.octets()],
                         constraint.add(StackPointerIncrement(4 * Arch.octets())),
                         assertion, clmax=1, noPadding=True)
    if not pppr_chains:
        verbose("Memcpy: Could not find suitable pop-pop-pop-ret gadget")
        return None
    pppr_gadget = pppr_chains[0].chain[0]

    # Build chain: one memcpy(dest, src, len) per substring.
    # NOTE(review): arguments are laid out Arg3, Arg2, Arg1 after the pppr
    # gadget; confirm this matches the stack order memcpy expects.
    chain = ROPChain()
    offset = 0
    for (piece_addr, piece_str) in pieces:
        if hex_info:
            escaped = '\\x'.join(["%02x" % ord(c) for c in piece_str])
            piece_info = "'" + '\\x' + escaped + "'"
        else:
            piece_info = "'" + piece_str + "'"
        piece_len = len(piece_str)
        chain.addPadding(func_addr, comment=string_ropg(func_name))
        chain.addGadget(pppr_gadget)
        chain.addPadding(piece_len, comment="Arg3: " + string_ropg(str(piece_len)))
        chain.addPadding(piece_addr, comment="Arg2: " + string_ropg(piece_info))
        chain.addPadding(addr + offset, comment="Arg1: " +
                         string_ropg("{} + {}".format(addr_str, offset)))
        # Next substring goes right after this one
        offset += piece_len
    return chain
def _adjust_ret(qtype, arg1, arg2, constraint, assertion, n, clmax=LMAX, record=None, comment=""):
    """
    Search with basic but adjust the bad returns they have

    Gadgets ending with a jmp/call to a register are found with _basic, then
    made chainable by loading into that register the address of a gadget
    (IP <- mem(SP + offset)) that resumes the chain.

    qtype      -- query type forwarded to _basic
    arg1       -- query destination (register number or (reg, cst))
    arg2       -- query source (constant or (reg, cst))
    constraint -- search constraint
    assertion  -- search assertion
    n          -- maximum number of chains to return
    clmax      -- maximum chain length
    record     -- SearchRecord (a new one is created when None); registers
                  that cannot be adjusted are added to it
    comment    -- unused

    Returns a list of at most n ROPChains ([] if none found).
    """
    # Test clmax
    if (clmax <= 0):
        return []
    # Test for ip
    if (arg1 == Arch.ipNum()):
        return []
    # Test for search record
    if (record is None):
        record = SearchRecord()

    # Registers of the query that the adjustment must not modify. Operands
    # are either a plain value or a (reg, cst) tuple; RegsNotModified needs
    # the register number (arg1 tuples and int arg2 used to be passed as-is
    # or subscripted, which broke).
    arg1_reg = arg1[0] if isinstance(arg1, (tuple, list)) else arg1
    arg2_reg = arg2[0] if isinstance(arg2, (tuple, list)) else arg2

    res = []
    possible = _basic(qtype, arg1, arg2,
                      constraint.add(Chainable(jmp=True, call=True)), assertion, n)
    padding = constraint.getValidPadding(Arch.currentArch.octets)
    for chain in possible:
        g = chain.chain[0]
        ret_reg = g.retValue.reg.num
        # Check if we already know that ret_reg can't be adjusted
        if (record.impossible_AdjustRet.check(ret_reg)):
            continue
        # Check if ret_reg not modified within the gadget
        if (ret_reg in g.modifiedRegs()):
            continue
        # Check if stack is preserved
        if (g.spInc is None):
            continue

        # Find adjustment
        if (g.spInc < 0):
            offset = -1 * g.spInc
            padding_length = 0
        else:
            # Number of padding WORDS after the gadget (spInc is in bytes)
            padding_length = g.spInc // Arch.octets()
            if (g.retType == RetType.JMP):
                offset = 0
            else:
                offset = Arch.octets()

        adjust_gadgets = search(QueryType.MEMtoREG, Arch.ipNum(),
                                (Arch.spNum(), offset),
                                constraint.add(RegsNotModified([arg1_reg])),
                                assertion, n=1, record=record)
        if (not adjust_gadgets):
            continue
        adjust_addr = int(validAddrStr(adjust_gadgets[0].chain[0],
                                       constraint.getBadBytes(), Arch.bits()), 16)

        # Put the gadget address in the register. The final chain is
        # adjust + g + padding_length paddings, so keep room for those.
        adjust = search(QueryType.CSTtoREG, ret_reg, adjust_addr,
                        constraint.add(RegsNotModified([arg2_reg])), assertion, n=1,
                        clmax=clmax - 1 - padding_length, record=record,
                        comment="Address of " + string_bold(str(adjust_gadgets[0].chain[0])))
        if (adjust):
            res.append(adjust[0].addGadget(g).addPadding(padding, n=padding_length))
            if (len(res) >= n):
                return res
        else:
            # Update the search record to say that reg_ret cannot be adjusted
            record.impossible_AdjustRet.add(ret_reg)
    return res
def _CSTtoREG_pop(reg, cst, env, n=1):
    """
    Returns a payload that puts cst into register reg by poping it from the stack

    For every gadget that loads `reg` from [SP + offset], the chain is
    <gadget> followed by one word per popped stack slot, with `cst` placed in
    the slot at `offset` and valid padding everywhere else.

    reg -- destination register number
    cst -- constant to load
    env -- search environment (lmax, constraint, assertion, comments, call
           counters); call counter and comment are restored before returning
    n   -- maximum number of chains to return

    Returns a list of ROPChains, or a FailRecord if none was found.
    """
    ID = StrategyType.CSTtoREG_POP
    octets = Arch.octets()

    ## Test for special cases
    # Test lmax
    if (env.getLmax() <= 0):
        return FailRecord(lmax=True)
    # Limit number of calls to this strategy
    elif (env.nbCalls(ID) >= 99):
        return FailRecord()
    # Check if the cst is in badBytes
    elif (not env.getConstraint().badBytes.verifyAddress(cst)):
        return FailRecord()

    # Set env
    env.addCall(ID)
    # Get comment
    if (env.hasComment(ID)):
        envHadComment = True
        comment = env.popComment(ID)
    else:
        envHadComment = False
        comment = "Constant: " + string_bold("0x{:x}".format(cst))

    ########################
    # Direct pop from the stack
    res = []
    res_fail = FailRecord()
    # Adapt constraint if ip <- cst
    if (reg != Arch.ipNum()):
        constraint2 = env.getConstraint().add(Chainable(ret=True))
    else:
        constraint2 = env.getConstraint()

    possible = DBPossiblePopOffsets(reg, constraint2, env.getAssertion())
    padding = env.getConstraint().getValidPadding(octets)
    for offset in sorted(off for off in possible.keys() if off >= 0):
        # If offsets are too big to fit in the lmax just break
        if (offset > env.getLmax() * octets):
            break
        # cst fills a whole stack word: an unaligned offset never matches a
        # slot below, which would yield an all-padding chain that never loads cst
        if (offset % octets != 0):
            continue
        # Keep gadgets whose popped slot lies before the ret and whose padding
        # fits in lmax
        possible_gadgets = [g for g in possible[offset]
                            if g.spInc >= octets
                            and g.spInc - octets > offset
                            and (g.spInc // octets - 1) <= env.getLmax()]
        # Pad the gadgets
        for gadget in possible_gadgets:
            chain = ROPChain([gadget])
            for i in range(0, gadget.spInc - octets, octets):
                if (i == offset):
                    chain.addPadding(cst, comment=comment)
                else:
                    chain.addPadding(padding)
            if (len(chain) <= env.getLmax()):
                res.append(chain)
            if (len(res) >= n):
                break
        if (len(res) >= n):
            break
    #########################

    # Restore env
    env.removeCall(ID)
    if (envHadComment):
        env.pushComment(ID, comment)
    return res if res else res_fail
def _REGtoREG_transitivity(arg1, arg2, env, n=1):
    """
    Perform REG1 <- REG2+CST with REG1 <- REG3 <- REG2+CST

    arg1 -- destination register (REG1)
    arg2 -- couple (REG2, CST) describing the source expression
    env  -- search environment; its lmax and unusable registers are modified
            temporarily and restored after every sub-search
    n    -- maximum number of chains to return

    Returns a list of ROPChains, or a FailRecord when none was found.
    """
    ID = StrategyType.REGtoREG_TRANSITIVITY

    ## Test for special cases
    # Test lmax
    if (env.getLmax() <= 0):
        return FailRecord(lmax=True)
    # If reg1 <- reg1 + 0, return
    elif (arg1 == arg2[0] and arg2[1] == 0):
        return FailRecord()
    # Limit number of calls to REGtoREG transitivity
    elif (env.callsHistory()[-2:] == [ID, ID]):
        return FailRecord()

    # Set env
    env.addCall(ID)

    # Search
    res = []
    res_fail = FailRecord()
    for inter_reg in Arch.registers():
        if (inter_reg == arg1 or (inter_reg == arg2[0] and arg2[1] == 0)
                or env.checkImpossible_REGtoREG(arg1, inter_reg, 0)
                or env.checkImpossible_REGtoREG(inter_reg, arg2[0], arg2[1])
                or inter_reg == Arch.ipNum() or inter_reg == Arch.spNum()):
            continue

        # Find reg1 <- inter_reg without using arg2 (keep 1 word for part 2)
        env.addUnusableReg(arg2[0])
        env.subLmax(1)
        inter_to_arg1_list = _search(QueryType.REGtoREG, arg1, (inter_reg, 0), env, n)
        env.removeUnusableReg(arg2[0])
        env.addLmax(1)
        if (not inter_to_arg1_list):
            res_fail.merge(inter_to_arg1_list)
            continue
        min_len_chain = min(len(chain) for chain in inter_to_arg1_list)

        # Find inter_reg <- arg2 without using arg1
        n2 = max(1, n // len(inter_to_arg1_list))
        env.addUnusableReg(arg1)
        env.subLmax(min_len_chain)
        arg2_to_inter_chains = _search(QueryType.REGtoREG, inter_reg, arg2, env, n2)
        # Restore env right away so that the `continue` below cannot leak the
        # reduced lmax / unusable arg1 into the next iterations
        env.addLmax(min_len_chain)
        env.removeUnusableReg(arg1)
        if (not arg2_to_inter_chains):
            res_fail.merge(arg2_to_inter_chains)
            continue

        # Combine; the total length is checked against the full lmax
        for arg2_to_inter in arg2_to_inter_chains:
            for inter_to_arg1 in inter_to_arg1_list:
                if (len(inter_to_arg1) + len(arg2_to_inter) <= env.getLmax()):
                    res.append(arg2_to_inter.addChain(inter_to_arg1, new=True))
                if (len(res) >= n):
                    break
            if (len(res) >= n):
                break
        if (len(res) >= n):
            break

    # Restore env
    env.removeCall(ID)
    return res if res else res_fail
def _basic(qtype, arg1, arg2, env, n=1, enablePreConds=False):
    """
    Search for gadgets basic method ( without chaining )
    Direct Database check

    qtype, arg1, arg2 -- query forwarded to DBSearch
    env -- search environment (lmax, constraint, assertion, noPadding flag)
    n   -- maximum number of gadgets requested from the database
    enablePreConds -- if True, DBSearch results are (gadget, preconditions)
        couples and (ROPChain, preconditions) couples are returned, unpadded

    Returns FailRecord(lmax=True) when lmax is exhausted and FailRecord() when
    the padded search finds nothing. The special-gadget, enablePreConds and
    noPadding paths return a (possibly empty) list.
    """
    if (env.getLmax() <= 0):
        return FailRecord(lmax=True)

    if (env.getNoPadding()):
        maxSpInc = None
    else:
        maxSpInc = env.getLmax() * Arch.octets()

    # Check for special gadgets
    if (qtype == QueryType.INT80 or qtype == QueryType.SYSCALL):
        gadgets = DBSearch(qtype, arg1, arg2, env.getConstraint(), env.getAssertion(),
                           n=n, maxSpInc=maxSpInc)
        return [ROPChain().addGadget(g) for g in gadgets]

    # Check if the type is IP <- ...
    # In this case we remove the CHAINABLE constraint which makes no sense
    if (arg1 == Arch.ipNum()):
        constraint2 = env.getConstraint().remove([CstrTypeID.CHAINABLE])
    else:
        constraint2 = env.getConstraint()

    # Check to add assertions when looking for Memory gadgets
    if (qtype == QueryType.CSTtoMEM or qtype == QueryType.REGtoMEM):
        assertion2 = env.getAssertion().add(RegsNoOverlap([(arg1[0], Arch.spNum())]))
    else:
        assertion2 = env.getAssertion()

    # Regular gadgets
    # maxSpInc -> +1 because we don't count the ret but -1 because the gadget takes one place
    gadgets = DBSearch(qtype, arg1, arg2, constraint2, assertion2, n,
                       enablePreConds=enablePreConds, maxSpInc=maxSpInc)
    if (enablePreConds):
        return [(ROPChain().addGadget(g[0]), g[1]) for g in gadgets]
    elif (env.getNoPadding()):
        return [ROPChain().addGadget(g) for g in gadgets]

    res = []
    padding = constraint2.getValidPadding(Arch.octets())
    for g in gadgets:
        chain = ROPChain().addGadget(g)
        # Padding the chain if possible. Floor division keeps range() integral
        # under Python 3 (identical to `/` on ints under Python 2).
        if (g.spInc > 0):
            for _ in range(0, g.spInc // Arch.octets() - 1):
                chain.addPadding(padding)
        res.append(chain)
    if (len(res) == 0):
        return FailRecord()
    return res
def store_constant_address(qtype, cst_addr, value, constraint=None, assertion=None, clmax=None, optimizeLen=False):
    """
    Does a XXXtoMEM kind of query BUT the memory address is a simple constant !

    Expected qtypes are only XXXtoMEM (MEMtoREG is also accepted and is
    forwarded unchanged to the register sub-search).
    cst_addr is the store address
    value is the value to store, a single cst or a couple (reg,cst)

    For each write gadget  mem[addr_reg + addr_cst] <- reg + cst  returned by
    DBAllPossibleWrites, try to build <set reg> <set addr_reg> <gadget> and
    <set addr_reg> <set reg> <gadget>. Without optimizeLen the first chain
    found is returned; with optimizeLen the shortest chain found is returned.
    Returns None when clmax <= 0 or when no chain was found.
    """
    if (clmax is None):
        clmax = STORE_CONSTANT_ADDRESS_LMAX
    elif (clmax <= 0):
        return None

    constr = Constraint() if constraint is None else constraint
    a = Assertion() if assertion is None else assertion

    # Tranform the query type
    if (qtype == QueryType.CSTtoMEM):
        qtype2 = QueryType.CSTtoREG
    elif (qtype == QueryType.REGtoMEM):
        qtype2 = QueryType.REGtoREG
    elif (qtype == QueryType.MEMtoREG):
        qtype2 = QueryType.MEMtoREG
    else:
        raise Exception(
            "Query type {} should not appear in this function!".format(qtype))

    def _with_write_gadget(chain, gadget):
        # Append the write gadget and pad the extra stack words it consumes
        chain = chain.addGadget(gadget)
        if (gadget.spInc > 0):
            padding_value = constr.getValidPadding(Arch.octets())
            chain = chain.addPadding(padding_value, n=(gadget.spInc // Arch.octets()) - 1)
        return chain

    tried_values = []
    tried_cst_addr = []
    best = None  # Best chain found so far (only kept if optimizeLen)
    shortest = clmax  # Length bound for new chains; shrinks when optimizeLen finds one
    # Gadgets that already write exactly `value` are tried first (see the
    # early returns at the bottom of the loop)
    possible_writes = sorted(DBAllPossibleWrites(constr.add(Chainable(ret=True)), a),
                             key=lambda x: 0 if (x[1] == value) else 1)
    for ((addr_reg, addr_cst), (reg, cst), gadget) in possible_writes:
        # Don't use the instruction/stack pointer registers
        if (reg == Arch.ipNum() or reg == Arch.spNum()
                or addr_reg == Arch.ipNum() or addr_reg == Arch.spNum()):
            continue
        # The gadget may already write the wanted (reg, cst): no setup needed
        value_is_reg = ((reg, cst) == value)
        value_to_reg = [ROPChain()] if value_is_reg else []

        # Compensate the constants the gadget adds to value and address
        if (not isinstance(value, tuple)):
            adjusted_value = value - cst
        else:
            adjusted_value = (value[0], value[1] - cst)
        adjusted_cst_addr = cst_addr - addr_cst
        # Words of padding the write gadget needs
        gadget_paddingLen = (gadget.spInc // Arch.octets()) - 1

        # Check if tried before
        if ((reg, cst) in tried_values or (addr_reg, addr_cst) in tried_cst_addr):
            continue

        ### Try to do reg first then addr_reg (addr setup must keep reg)
        clmax2 = shortest - gadget_paddingLen - 1
        if (not value_is_reg):
            value_to_reg = search(qtype2, reg, adjusted_value, constr, a,
                                  clmax=clmax2, n=1, optimizeLen=optimizeLen)
            if (not value_to_reg):
                tried_values.append((reg, cst))
                continue
            clmax2 = clmax2 - len(value_to_reg[0])
        addr_to_reg = search(QueryType.CSTtoREG, addr_reg, adjusted_cst_addr,
                             constr.add(RegsNotModified([reg])), a,
                             clmax=clmax2, n=1, optimizeLen=optimizeLen)
        if (addr_to_reg):
            res = _with_write_gadget(value_to_reg[0].addChain(addr_to_reg[0]), gadget)
            if (not optimizeLen):
                return res
            best = res if best is None else min(best, res)
            shortest = len(best)

        ### Try to do addr_reg first and then reg (value setup must keep addr_reg)
        # Skipped when value_is_reg: reg must keep its initial content for the
        # whole chain, which the unconstrained addr search below cannot
        # guarantee (the constrained variant was tried just above). It would
        # also reuse value_to_reg[0], which addChain() above may have extended.
        if (not value_is_reg):
            clmax2 = shortest - gadget_paddingLen - 1
            addr_to_reg = search(QueryType.CSTtoREG, addr_reg, adjusted_cst_addr,
                                 constr, a, clmax=clmax2, n=1, optimizeLen=optimizeLen)
            if (not addr_to_reg):
                tried_cst_addr.append((addr_reg, addr_cst))
                continue
            clmax2 = clmax2 - len(addr_to_reg[0])
            value_to_reg = search(qtype2, reg, adjusted_value,
                                  constr.add(RegsNotModified([addr_reg])), a,
                                  clmax=clmax2, n=1, optimizeLen=optimizeLen)
            if (value_to_reg):
                res = _with_write_gadget(addr_to_reg[0].addChain(value_to_reg[0]), gadget)
                if (not optimizeLen):
                    return res
                best = res if best is None else min(best, res)
                shortest = len(best)

        # 5 = two pops for addr_reg and reg + 1 for the write gadget
        # So since 5 is the shortest possible with two pops we can return
        # We can have < 5 if reg is already equal to 'value' argument
        # But we try this case first (see sorted()) when getting possibleWrites ;)
        if (((not optimizeLen) or (not value_is_reg)) and (best is not None) and len(best) <= 5):
            return best
        elif (optimizeLen and (best is not None) and len(best) <= 3):
            return best
    return best
def _basic(qtype, arg1, arg2, constraint, assertion, n=1, clmax=LMAX, noPadding=False):
    """
    Search for gadgets basic method ( without chaining )
    Direct Database check

    qtype, arg1, arg2     -- query forwarded to DBSearch
    constraint, assertion -- search constraint and assertion
    n         -- maximum number of gadgets requested from the database
    clmax     -- maximum chain length in stack words; also bounds the
                 gadgets' stack increment unless noPadding is set
    noPadding -- return bare single-gadget chains without padding

    Returns a (possibly empty) list of single-gadget ROPChains.
    """
    # Test clmax
    if (clmax <= 0):
        return []
    maxSpInc = None if noPadding else clmax * Arch.octets()

    # Check for special gadgets
    if (qtype == QueryType.INT80 or qtype == QueryType.SYSCALL):
        # Honour n like the regular search below (it was hard-coded to 1)
        gadgets = DBSearch(qtype, arg1, arg2, constraint, assertion, n=n, maxSpInc=maxSpInc)
        return [ROPChain().addGadget(g) for g in gadgets]

    # Check if the type is IP <- ...
    # In this case we remove the CHAINABLE constraint which makes no sense
    if (arg1 == Arch.ipNum()):
        constraint2 = constraint.remove([CstrTypeID.CHAINABLE])
    else:
        constraint2 = constraint

    # Check to add assertions when looking for Memory gadgets
    if (qtype == QueryType.CSTtoMEM or qtype == QueryType.REGtoMEM):
        assertion2 = assertion.add(RegsNoOverlap([(arg1[0], Arch.spNum())]))
    else:
        assertion2 = assertion

    # Regular gadgets
    # maxSpInc -> +1 because we don't count the ret but -1 because the gadget takes one place
    gadgets = DBSearch(qtype, arg1, arg2, constraint2, assertion2, n, maxSpInc=maxSpInc)
    if (noPadding):
        return [ROPChain().addGadget(g) for g in gadgets]

    res = []
    padding = constraint2.getValidPadding(Arch.currentArch.octets)
    for g in gadgets:
        chain = ROPChain().addGadget(g)
        # Padding the chain if possible. Floor division keeps range() integral
        # under Python 3 (identical to `/` on ints under Python 2).
        if (g.spInc > 0):
            for _ in range(0, g.spInc // Arch.octets() - 1):
                chain.addPadding(padding)
        res.append(chain)
    return res