class CodeAnalyzerTests(unittest.TestCase):
    """Exercise CodeAnalyzer on 32-bit x86 snippets lifted to REIL."""

    def setUp(self):
        arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)
        smt_solver = SmtSolver()
        smt_translator = SmtTranslator(smt_solver, arch_info.address_size)
        smt_translator.set_arch_alias_mapper(arch_info.alias_mapper)
        smt_translator.set_arch_registers_size(arch_info.registers_size)

        self._arch_info = arch_info
        self._smt_solver = smt_solver
        self._smt_translator = smt_translator
        self._x86_parser = X86Parser(ARCH_X86_MODE_32)
        self._x86_translator = X86Translator(ARCH_X86_MODE_32)
        self._code_analyzer = CodeAnalyzer(smt_solver, smt_translator,
                                           arch_info)

    def test_add_reg_reg(self):
        self._load_asm("add eax, ebx")

        analyzer = self._code_analyzer
        eax_pre = analyzer.get_register_expr("eax", mode="pre")
        eax_post = analyzer.get_register_expr("eax", mode="post")

        self._constrain(
            eax_pre != 42,   # Pre-condition
            eax_post == 42,  # Post-condition
        )

        self.assertEqual(analyzer.check(), 'sat')
        self.assertNotEqual(analyzer.get_expr_value(eax_pre), 42)
        self.assertEqual(analyzer.get_expr_value(eax_post), 42)

    def test_add_reg_mem(self):
        self._load_asm("add eax, [ebx + 0x1000]")

        analyzer = self._code_analyzer
        eax_pre = analyzer.get_register_expr("eax", mode="pre")
        eax_post = analyzer.get_register_expr("eax", mode="post")

        self._constrain(
            eax_pre != 42,   # Pre-condition
            eax_post == 42,  # Post-condition
        )

        self.assertEqual(analyzer.check(), 'sat')
        self.assertNotEqual(analyzer.get_expr_value(eax_pre), 42)
        self.assertEqual(analyzer.get_expr_value(eax_post), 42)

    def test_add_mem_reg(self):
        self._load_asm("add [eax + 0x1000], ebx")

        analyzer = self._code_analyzer
        eax_pre = analyzer.get_register_expr("eax", mode="pre")
        mem_pre = analyzer.get_memory(mode="pre")
        mem_post = analyzer.get_memory(mode="post")

        self._constrain(
            mem_pre[eax_pre + 0x1000] != 42,   # Pre-condition
            mem_post[eax_pre + 0x1000] == 42,  # Post-condition
        )

        self.assertEqual(analyzer.check(), 'sat')
        self.assertNotEqual(
            analyzer.get_expr_value(mem_pre[eax_pre + 0x1000]), 42)
        self.assertEqual(
            analyzer.get_expr_value(mem_post[eax_pre + 0x1000]), 42)

    # Helpers
    # ======================================================================== #
    def _load_asm(self, *asm_lines):
        """Parse each x86 line and feed its REIL translation to the analyzer."""
        asm_instrs = [self._x86_parser.parse(line) for line in asm_lines]
        for reil_instr in self.__asm_to_reil(asm_instrs):
            self._code_analyzer.add_instruction(reil_instr)

    def _constrain(self, *constraints):
        """Add each constraint to the analyzer, in the order given."""
        for constraint in constraints:
            self._code_analyzer.add_constraint(constraint)

    def __asm_to_reil(self, instructions):
        """Translate parsed x86 instructions into one flat list of REIL ones.

        Each instruction's ``address`` is overwritten with its index in
        ``instructions`` before any translation happens.
        """
        for index, asm_instr in enumerate(instructions):
            asm_instr.address = index

        # Translate everything first, then flatten the per-instruction lists.
        translated = list(map(self._x86_translator.translate, instructions))

        flat = []
        for batch in translated:
            flat.extend(batch)
        return flat
class SymExecResult(object):
    """Solver-backed witness for a symbolic execution path.

    Builds an SMT problem from three parts: an initial state, a path of
    REIL instructions with the outcome of each conditional branch, and a
    final state. It checks that the problem is satisfiable. Callers can
    then query the values that registers and memory had *before* the path
    ("pre" mode) in the solver's model.
    """

    def __init__(self, arch, initial_state, path, final_state):
        """Build and check the SMT problem for a path.

        Args:
            arch: Architecture information. Must expose ``address_size``,
                ``alias_mapper`` and ``registers_size``.
            initial_state: Constraints on the state before the path. Must
                expose ``get_registers()``, ``get_memory()`` and
                ``get_constraints()``.
            path: Iterable of ``(reil_instr, branch_taken)`` pairs.
            final_state: Constraints on the state after the path, with the
                same interface as ``initial_state``.

        Raises:
            AssertionError: If the solver does not report the combined
                constraints as ``"sat"``.
        """
        self.__initial_state = initial_state
        self.__path = path
        self.__final_state = final_state
        self.__arch = arch

        self.__smt_solver = None
        self.__smt_translator = None
        self.__code_analyzer = None

        self.__initialize_analyzer()
        self.__setup_solver()

    def query_register(self, register):
        """Return the model value of ``register`` before the path ran."""
        # TODO: This method should return an iterator.
        smt_expr = self.__code_analyzer.get_register_expr(register, mode="pre")
        value = self.__code_analyzer.get_expr_value(smt_expr)

        return value

    def query_memory(self, address, size):
        """Return the model value of memory at ``address`` before the path ran.

        ``size`` is passed straight to ``get_memory_expr``. Presumably it is
        a byte count, since states are constrained one byte (size 1) per
        entry. TODO confirm.
        """
        # TODO: This method should return an iterator.
        smt_expr = self.__code_analyzer.get_memory_expr(address, size, mode="pre")
        value = self.__code_analyzer.get_expr_value(smt_expr)

        return value

    # Auxiliary methods
    # ======================================================================== #
    def __initialize_analyzer(self):
        """Create the Z3 solver, its SMT translator and the code analyzer."""
        self.__smt_solver = Z3Solver()

        self.__smt_translator = SmtTranslator(self.__smt_solver, self.__arch.address_size)
        self.__smt_translator.set_arch_alias_mapper(self.__arch.alias_mapper)
        self.__smt_translator.set_arch_registers_size(self.__arch.registers_size)

        self.__code_analyzer = CodeAnalyzer(self.__smt_solver, self.__smt_translator, self.__arch)

    def __setup_solver(self):
        """Load the states and the path into the analyzer, then check them.

        Raises:
            AssertionError: If ``check()`` does not return ``"sat"``.
        """
        self.__set_initial_state(self.__initial_state)
        self.__add_trace_to_solver(self.__path)
        self.__set_final_state(self.__final_state)

        # check() must run unconditionally. It used to be wrapped in an
        # ``assert``, which ``python -O`` strips entirely. The solver would
        # then never run, unsat paths would be accepted silently, and later
        # get_expr_value() queries (which presumably need a prior check())
        # would have no model.
        status = self.__code_analyzer.check()
        if status != "sat":
            # AssertionError is kept so callers that relied on the old
            # ``assert`` still catch the same exception type.
            raise AssertionError(
                "path constraints are not satisfiable "
                "(solver returned {!r})".format(status))

    def __set_initial_state(self, initial_state):
        """Constrain the pre-path ("pre") state with ``initial_state``."""
        self.__set_state(initial_state, mode="pre")

    def __set_final_state(self, final_state):
        """Constrain the post-path ("post") state with ``final_state``."""
        self.__set_state(final_state, mode="post")

    def __set_state(self, state, mode):
        """Add the register, memory and extra constraints of ``state``.

        Args:
            state: Object exposing ``get_registers()``, ``get_memory()``
                and ``get_constraints()``.
            mode: Which copy of the machine state to constrain, either
                ``"pre"`` or ``"post"``.
        """
        analyzer = self.__code_analyzer

        # Registers.
        for reg, val in state.get_registers().items():
            reg_expr = analyzer.get_register_expr(reg, mode=mode)
            analyzer.add_constraint(reg_expr == val)

        # Memory is constrained one entry at a time with size 1. This
        # presumably means get_memory() maps addresses to byte values.
        # TODO confirm.
        for addr, val in state.get_memory().items():
            mem_expr = analyzer.get_memory_expr(addr, 1, mode=mode)
            analyzer.add_constraint(mem_expr == val)

        # Free-form constraints.
        for constr in state.get_constraints():
            analyzer.add_constraint(constr)

    def __add_trace_to_solver(self, trace):
        """Add the executed path to the analyzer.

        A JCC whose first operand is a register is not added as an
        instruction. Instead, the recorded branch outcome becomes a
        constraint on that operand: non-zero if the branch was taken, zero
        otherwise. Every other instruction is added as-is.

        Args:
            trace: Iterable of ``(reil_instr, branch_taken)`` pairs.
        """
        analyzer = self.__code_analyzer

        for reil_instr, branch_taken in trace:
            is_register_jcc = (
                reil_instr.mnemonic == ReilMnemonic.JCC and
                isinstance(reil_instr.operands[0], ReilRegisterOperand))

            if not is_register_jcc:
                analyzer.add_instruction(reil_instr)
                continue

            oprnd_expr = analyzer.get_operand_expr(reil_instr.operands[0])

            if branch_taken:
                branch_expr = oprnd_expr != 0x0
            else:
                branch_expr = oprnd_expr == 0x0

            analyzer.add_constraint(branch_expr)