Example #1
0
    def __initialize_analyzer(self):
        """Create the Z3 solver, SMT translator and code analyzer for ``self.__arch``."""
        arch = self.__arch

        solver = Z3Solver()
        self.__smt_solver = solver

        translator = SmtTranslator(solver, arch.address_size)
        self.__smt_translator = translator
        translator.set_arch_alias_mapper(arch.alias_mapper)
        translator.set_arch_registers_size(arch.registers_size)

        self.__code_analyzer = CodeAnalyzer(solver, translator, arch)
    def setUp(self):
        """Create the parser, solver and SMT translator (configured for x86-32)."""
        self._address_size = 32
        self._parser = ReilParser()
        self._solver = SmtSolver()
        translator = SmtTranslator(self._solver, self._address_size)
        self._translator = translator

        arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)
        self._arch_info = arch_info

        translator.set_arch_alias_mapper(arch_info.alias_mapper)
        translator.set_arch_registers_size(arch_info.registers_size)
Example #3
0
 def setUp(self):
     """Assemble x86-32 fixtures: arch info, mock memory, SMT translator and BB builder."""
     arch = X86ArchitectureInformation(ARCH_X86_MODE_32)
     self._arch_info = arch
     self._operand_size = arch.operand_size
     self._memory = MemoryMock()

     solver = SmtSolver()
     self._smt_solver = solver
     translator = SmtTranslator(solver, self._operand_size)
     self._smt_translator = translator
     translator.set_arch_alias_mapper(arch.alias_mapper)
     translator.set_arch_registers_size(arch.registers_size)

     disassembler = X86Disassembler()
     self._disasm = disassembler
     ir_translator = X86Translator()
     self._ir_translator = ir_translator
     self._bb_builder = BasicBlockBuilder(disassembler, self._memory, ir_translator)
Example #4
0
    def setUp(self):
        """Build ARM-mode fixtures for gadget classification and verification."""
        arch = ArmArchitectureInformation(ARCH_ARM_MODE_ARM)
        self._arch_info = arch

        solver = SmtSolver()
        self._smt_solver = solver
        translator = SmtTranslator(solver, arch.address_size)
        self._smt_translator = translator

        emulator = ReilEmulator(arch)
        self._ir_emulator = emulator

        translator.set_arch_alias_mapper(arch.alias_mapper)
        translator.set_arch_registers_size(arch.registers_size)

        analyzer = CodeAnalyzer(solver, translator, arch)
        self._code_analyzer = analyzer

        self._g_classifier = GadgetClassifier(emulator, arch)
        self._g_verifier = GadgetVerifier(analyzer, arch)
Example #5
0
    def __init__(self, binary):
        """Open ``binary`` as an ELF file and build the x86 gadget-analysis toolchain.

        Sets up a REIL emulator, gadget classifier, SMT solver/translator, code
        analyzer, the gadget helper objects and a full-translation x86 REIL
        translator, all configured from the x86 mode matching the ELF class.

        :param binary: passed straight to ``elffile.ELFFile`` (presumably an
            open binary file object — confirm against callers).
        :raises ValueError: if the ELF class is neither 32 nor 64.
        """
        self.elf = elffile.ELFFile(binary)

        # Reject unknown ELF classes up front instead of leaving arch_info
        # unset and failing later with an AttributeError.
        if self.elf.elfclass == 32:
            self.arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)
        elif self.elf.elfclass == 64:
            self.arch_info = X86ArchitectureInformation(ARCH_X86_MODE_64)
        else:
            raise ValueError("Unsupported ELF class: %r" % (self.elf.elfclass,))

        self.emulator = ReilEmulator(self.arch_info.address_size)
        self.emulator.set_arch_registers(self.arch_info.registers_gp)
        self.emulator.set_arch_registers_size(self.arch_info.register_size)
        self.emulator.set_reg_access_mapper(self.arch_info.register_access_mapper())

        self.classifier = GadgetClassifier(self.emulator, self.arch_info)

        self.smt_solver = SmtSolver()
        self.smt_translator = SmtTranslator(self.smt_solver, self.arch_info.address_size)

        self.smt_translator.set_reg_access_mapper(self.arch_info.register_access_mapper())
        self.smt_translator.set_arch_registers_size(self.arch_info.register_size)

        self.code_analyzer = CodeAnalyzer(self.smt_solver, self.smt_translator)

        self.gadgets = {}
        self.classified_gadgets = {}

        # Helper objects receive this instance so they can reach the shared tools.
        self.regset = RegSet(self)
        self.ccf = CCFlag(self)
        self.ams = ArithmeticStore(self)
        self.memstr = MemoryStore(self)

        self.reil_translator = X86Translator(architecture_mode=self.arch_info.architecture_mode,
                                             translation_mode=FULL_TRANSLATION)
Example #6
0
    def setUp(self):
        """Prepare x86-32 parser/translator fixtures and an SMT-backed code analyzer."""
        mode = ARCH_X86_MODE_32
        arch = X86ArchitectureInformation(mode)
        self._arch_info = arch

        solver = SmtSolver()
        self._smt_solver = solver

        translator = SmtTranslator(solver, arch.address_size)
        self._smt_translator = translator
        translator.set_arch_alias_mapper(arch.alias_mapper)
        translator.set_arch_registers_size(arch.registers_size)

        self._x86_parser = X86Parser(mode)
        self._x86_translator = X86Translator(mode)

        self._code_analyzer = CodeAnalyzer(solver, translator, arch)
 def setUp(self):
     """Assemble x86-32 fixtures; the basic-block builder also receives the arch info."""
     arch = X86ArchitectureInformation(ARCH_X86_MODE_32)
     self._arch_info = arch
     self._operand_size = arch.operand_size
     self._memory = MemoryMock()

     solver = SmtSolver()
     self._smt_solver = solver
     translator = SmtTranslator(solver, self._operand_size)
     self._smt_translator = translator
     translator.set_arch_alias_mapper(arch.alias_mapper)
     translator.set_arch_registers_size(arch.registers_size)

     disassembler = X86Disassembler()
     self._disasm = disassembler
     ir_translator = X86Translator()
     self._ir_translator = ir_translator
     self._bb_builder = BasicBlockBuilder(disassembler, self._memory, ir_translator, arch)
    def setUp(self):
        """Create ARM-mode emulator, code analyzer and gadget classifier/verifier."""
        arch_info = ArmArchitectureInformation(ARCH_ARM_MODE_ARM)
        self._arch_info = arch_info

        smt_solver = SmtSolver()
        self._smt_solver = smt_solver
        smt_translator = SmtTranslator(smt_solver, arch_info.address_size)
        self._smt_translator = smt_translator

        ir_emulator = ReilEmulator(arch_info)
        self._ir_emulator = ir_emulator

        smt_translator.set_arch_alias_mapper(arch_info.alias_mapper)
        smt_translator.set_arch_registers_size(arch_info.registers_size)

        code_analyzer = CodeAnalyzer(smt_solver, smt_translator, arch_info)
        self._code_analyzer = code_analyzer

        self._g_classifier = GadgetClassifier(ir_emulator, arch_info)
        self._g_verifier = GadgetVerifier(code_analyzer, arch_info)
Example #9
0
    def setUp(self):
        """Configure x86-64 parser, REIL translator, REIL emulator and SMT translator."""
        self.trans_mode = FULL_TRANSLATION
        self.arch_mode = ARCH_X86_MODE_64

        arch = X86ArchitectureInformation(self.arch_mode)
        self.arch_info = arch

        self.x86_parser = X86Parser(self.arch_mode)
        self.x86_translator = X86Translator(self.arch_mode, self.trans_mode)

        solver = SmtSolver()
        self.smt_solver = solver
        smt = SmtTranslator(solver, arch.address_size)
        self.smt_translator = smt

        emulator = ReilEmulator(arch.address_size)
        self.reil_emulator = emulator
        emulator.set_arch_registers(arch.registers_gp)
        emulator.set_arch_registers_size(arch.register_size)
        emulator.set_reg_access_mapper(arch.register_access_mapper())

        smt.set_reg_access_mapper(arch.register_access_mapper())
        smt.set_arch_registers_size(arch.register_size)
Example #10
0
 def setUp(self):
     """Create a ReilParser plus a 32-bit SmtSolver/SmtTranslator pair."""
     address_size = 32
     self._address_size = address_size
     self._parser = ReilParser()
     solver = SmtSolver()
     self._solver = solver
     self._translator = SmtTranslator(solver, address_size)
Example #11
0
 def setUp(self):
     """Instantiate the parser, solver and a translator using 32-bit addresses."""
     self._address_size = 32
     self._parser = ReilParser()
     smt_solver = SmtSolver()
     self._solver = smt_solver
     self._translator = SmtTranslator(smt_solver, self._address_size)
Example #12
0
class SmtTranslatorTests(unittest.TestCase):
    """Satisfiability tests for SmtTranslator on small REIL instruction sequences.

    Translation and constraint assertion always happen; the module-level
    ``VERBOSE`` flag only controls the diagnostic printing.
    """

    def setUp(self):
        self._address_size = 32
        self._parser = ReilParser()
        self._solver = SmtSolver()
        self._translator = SmtTranslator(self._solver, self._address_size)

    # Helpers
    # ======================================================================== #
    @staticmethod
    def _set_operand_sizes(instrs, layout, size=32):
        """Set ``size`` on operands: ``layout[k]`` lists operand indexes of ``instrs[k]``."""
        for instr, indexes in zip(instrs, layout):
            for idx in indexes:
                instr.operands[idx].size = size

    def _add_formulae(self, trans):
        """Add a translation result (None, one formula, or a list of them) to the solver."""
        # NOTE(review): the sext/bsh tests below index translate()'s result
        # with [0], so it may return a list; a bare formula is accepted too.
        if trans is None:
            return
        if isinstance(trans, (list, tuple)):
            for form in trans:
                self._solver.add(form)
        else:
            self._solver.add(trans)

    def _translate_and_add(self, instrs):
        """Translate every instruction and add the formulae to the solver.

        Translation must not depend on VERBOSE; otherwise the solver is
        checked without the program's formulae and the test is vacuous.
        """
        if VERBOSE:
            print("[+] Instructions:")
        for instr in instrs:
            trans = self._translator.translate(instr)
            self._add_formulae(trans)
            if VERBOSE:
                print("    %-20s -> %-20s" % (instr, trans))

    def _add_constraints(self, constraints):
        """Add every constraint to the solver, printing them when VERBOSE."""
        if VERBOSE:
            print("[+] Constraints :")
        for i, constr in enumerate(constraints):
            self._solver.add(constr)
            if VERBOSE:
                print("    %2d : %s" % (i, constr))

    def _check_sat(self, registers=(), exprs=()):
        """Return whether the solver is 'sat'.

        When VERBOSE and satisfiable, print the model value of each register in
        ``registers`` and of each ``(label, expr)`` pair in ``exprs``.
        """
        is_sat = self._solver.check() == 'sat'

        if VERBOSE:
            print("[+] Satisfiability : %s" % str(is_sat))

            if is_sat:
                for name in registers:
                    value = self._solver.getvaluebyname(self._translator.get_curr_name(name))
                    print("    %s : 0x%08x" % (name, value))
                for label, expr in exprs:
                    print("    %s : 0x%08x" % (label, self._solver.getvalue(expr)))

            print("~" * 80)

        return is_sat

    def _translate_single(self, text):
        """Parse one REIL instruction, reset the solver and add its translation."""
        instr = self._parser.parse([text])[0]

        self._solver.reset()

        self._add_formulae(self._translator.translate(instr))

    def _assert_unsat(self, constraints):
        """Add ``constraints`` and assert the solver reports 'unsat'."""
        for constr in constraints:
            self._solver.add(constr)

        self.assertEqual(self._solver.check(), 'unsat')

    # Tests
    # ======================================================================== #
    def test_add_reg_reg(self):
        if VERBOSE:
            print("\n[+] Test: test_add_reg_reg")

        # add eax, ebx
        instrs = self._parser.parse([
            "add [eax, ebx, t0]",
            "str [t0, e, eax]",
        ])
        self._set_operand_sizes(instrs, [(0, 1, 2), (0, 2)])

        self._solver.reset()

        self._translate_and_add(instrs)

        self._add_constraints([
            self._solver.mkBitVec(32, self._translator.get_curr_name("eax")) == 42,
            self._solver.mkBitVec(32, self._translator.get_init_name("eax")) != 42,
        ])

        self.assertTrue(self._check_sat(registers=("eax", "ebx")))

    def test_add_reg_mem(self):
        if VERBOSE:
            print("\n[+] Test: test_add_reg_mem")

        # add eax, [ebx]
        instrs = self._parser.parse([
            "ldm [ebx, EMPTY, t0]",
            "add [eax, t0, t1]",
            "str [t1, EMPTY, eax]",
        ])
        self._set_operand_sizes(instrs, [(0, 2), (0, 1, 2), (0, 2)])

        self._solver.reset()

        self._translate_and_add(instrs)

        self._add_constraints([
            self._solver.mkBitVec(32, self._translator.get_curr_name("eax")) == 42,
            self._solver.mkBitVec(32, self._translator.get_init_name("eax")) != 42,
        ])

        self.assertTrue(self._check_sat(registers=("eax", "ebx")))

    def test_add_mem_reg(self):
        if VERBOSE:
            print("\n[+] Test: test_add_mem_reg")

        # add [eax], ebx
        instrs = self._parser.parse([
            "ldm [eax, EMPTY, t0]",
            "add [t0, ebx, t1]",
            "stm [t1, EMPTY, eax]",
        ])
        self._set_operand_sizes(instrs, [(0, 2), (0, 1, 2), (0, 2)])

        self._solver.reset()

        # Initial-state constraint: [eax_0] must not already hold 42.
        mem = self._translator.get_memory()
        eax = self._solver.mkBitVec(32, "eax_0")

        constraint = (mem[eax] != 42)

        if VERBOSE:
            print("constraint : %s" % constraint)

        self._solver.add(constraint)

        self._translate_and_add(instrs)

        # NOTE(review): memory is re-read after translation on the assumption
        # that get_memory() returns the current (post-``stm``) memory version;
        # reusing the pre-translation handle could contradict the constraint
        # above. Confirm against SmtTranslator.get_memory.
        mem_post = self._translator.get_memory()

        self._add_constraints([
            mem_post[eax] == 42,
            self._solver.mkBitVec(32, self._translator.get_init_name("t0")) != 42,
        ])

        is_sat = self._check_sat(
            registers=("eax", "ebx", "t0"),
            exprs=[("[eax]", mem_post[eax])],
        )

        self.assertTrue(is_sat)

    def test_add_mem_reg_2(self):
        if VERBOSE:
            print("\n[+] Test: test_add_mem_reg_2")

        # add [eax + 0x1000], ebx
        instrs = self._parser.parse([
            "add [eax, 0x1000, t0]",
            "ldm [t0, e, t1]",
            "add [t1, ebx, t2]",
            "stm [t2, e, t0]",
        ])
        self._set_operand_sizes(instrs, [(0, 1, 2), (0, 2), (0, 1, 2), (0, 2)])

        self._solver.reset()

        # Initial-state constraint: [eax_0 + 0x1000] must not already hold 42.
        mem = self._translator.get_memory()
        eax = self._solver.mkBitVec(32, "eax_0")
        off = BitVec(32, "#x%08x" % 0x1000)

        constraint = (mem[eax + off] != 42)

        if VERBOSE:
            print("constraint : %s" % constraint)

        self._solver.add(constraint)

        self._translate_and_add(instrs)

        # NOTE(review): see test_add_mem_reg — assumes get_memory() now returns
        # the post-``stm`` memory version; confirm.
        mem_post = self._translator.get_memory()

        self._add_constraints([
            mem_post[eax + off] == 42,
            self._solver.mkBitVec(32, self._translator.get_init_name("t1")) != 42,
        ])

        is_sat = self._check_sat(
            registers=("eax", "ebx", "t0", "t1"),
            exprs=[("[eax + off]", mem_post[eax + off])],
        )

        self.assertTrue(is_sat)

    def test_mul(self):
        if VERBOSE:
            print("\n[+] Test: test_mul")

        instrs = self._parser.parse([
            "mul [0x0, 0x1, t0]",
        ])

        # TODO: Revisit this; the output size should be 64.
        self._set_operand_sizes(instrs, [(0, 1, 2)])

        self._solver.reset()

        self._translate_and_add(instrs)

        self._add_constraints([
            self._solver.mkBitVec(32, self._translator.get_curr_name("t0")) == 0,
            self._solver.mkBitVec(32, self._translator.get_init_name("t0")) != 0,
        ])

        self.assertTrue(self._check_sat(registers=("t0",)))

    def test_sext_1(self):
        self._translate_single("sext [WORD 0xffff, EMPTY, DWORD t1]")

        self._assert_unsat([
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) != 0xffffffff,
        ])

    def test_sext_2(self):
        self._translate_single("sext [WORD 0x7fff, EMPTY, DWORD t1]")

        self._assert_unsat([
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) != 0x00007fff,
        ])

    def test_bsh_left_1(self):
        self._translate_single("bsh [DWORD t1, DWORD 16, QWORD t2]")

        self._assert_unsat([
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) == 0xffffffff,
            self._solver.mkBitVec(64, self._translator.get_curr_name("t2")) != 0x0000ffffffff0000,
        ])

    def test_bsh_left_2(self):
        self._translate_single("bsh [DWORD t1, DWORD 16, DWORD t2]")

        self._assert_unsat([
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) == 0xffffffff,
            self._solver.mkBitVec(32, self._translator.get_curr_name("t2")) != 0xffff0000,
        ])

    def test_bsh_left_3(self):
        self._translate_single("bsh [DWORD t1, DWORD 16, WORD t2]")

        self._assert_unsat([
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) == 0xffffffff,
            self._solver.mkBitVec(16, self._translator.get_curr_name("t2")) != 0x0000,
        ])

    def test_bsh_right_1(self):
        self._translate_single("bsh [DWORD t1, DWORD -16, QWORD t2]")

        self._assert_unsat([
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) == 0xffffffff,
            self._solver.mkBitVec(64, self._translator.get_curr_name("t2")) != 0x000000000000ffff,
        ])

    def test_bsh_right_2(self):
        self._translate_single("bsh [DWORD t1, DWORD -16, DWORD t2]")

        self._assert_unsat([
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) == 0xffffffff,
            self._solver.mkBitVec(32, self._translator.get_curr_name("t2")) != 0x0000ffff,
        ])

    def test_bsh_right_3(self):
        self._translate_single("bsh [DWORD t1, DWORD -16, WORD t2]")

        self._assert_unsat([
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) == 0xffffffff,
            self._solver.mkBitVec(16, self._translator.get_curr_name("t2")) != 0xffff,
        ])
Example #13
0
class State(object):
    """Register/memory snapshot plus a list of constraints.

    Concrete values are kept in plain dicts (``_registers``, ``_memory``).
    The ``query_*`` methods instead ask an internal ``CodeAnalyzer`` (built
    over a ``Z3Solver`` and ``SmtTranslator`` for ``arch``) for an
    expression in ``mode="pre"``.
    """

    def __init__(self, arch, mode=None):
        self._registers = {}    # register -> value, as given to write_register
        self._memory = {}       # byte address -> byte value (0..0xff)
        self._constraints = []
        self._mode = mode   # {"initial", "final"}

        self.__arch = arch

        self.__smt_solver = None
        self.__smt_translator = None

        self.__code_analyzer = None

        self.__initialize_analyzer()

    def read_register(self, register):
        """Return the value written for ``register``, or None if never written."""
        return self._registers.get(register, None)

    def write_register(self, register, value):
        """Record ``value`` as the concrete value of ``register``."""
        self._registers[register] = value

    def query_register(self, register):
        """Return the code analyzer's expression for ``register`` (``mode="pre"``)."""
        smt_expr = self.__code_analyzer.get_register_expr(register, mode="pre")

        return smt_expr

    def get_registers(self):
        """Return a shallow copy of the written registers."""
        return dict(self._registers)

    def read_memory(self, address, size):
        """Return the little-endian value of ``size`` bytes starting at ``address``.

        This mirrors the byte layout used by ``write_memory``. Returns None if
        any requested byte has never been written.

        :raises ValueError: if ``size`` is less than 1.
        """
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
        if size < 1:
            raise ValueError("size must be >= 1, got %r" % (size,))

        value = 0
        for i in range(size):
            byte = self._memory.get(address + i)
            if byte is None:
                return None
            value |= byte << (i * 8)

        return value

    def write_memory(self, address, size, value):
        """Store the low ``size`` bytes of ``value`` at ``address``, little-endian."""
        for i in range(0, size):
            self._memory[address + i] = (value >> (i * 8)) & 0xff

    def query_memory(self, address, size):
        """Return the code analyzer's expression for memory at ``address`` (``mode="pre"``)."""
        smt_expr = self.__code_analyzer.get_memory_expr(address, size, mode="pre")

        return smt_expr

    def get_memory(self):
        """Return a shallow copy of the byte-addressed memory map."""
        return dict(self._memory)

    def add_constraint(self, constraint):
        """Append ``constraint`` to this state's constraint list."""
        self._constraints.append(constraint)

    def get_constraints(self):
        """Return a copy of the constraint list."""
        return list(self._constraints)

    # Auxiliary methods
    # ======================================================================== #
    def __initialize_analyzer(self):
        """Create the solver, translator and code analyzer used by the query methods."""
        self.__smt_solver = Z3Solver()

        self.__smt_translator = SmtTranslator(self.__smt_solver, self.__arch.address_size)
        self.__smt_translator.set_arch_alias_mapper(self.__arch.alias_mapper)
        self.__smt_translator.set_arch_registers_size(self.__arch.registers_size)

        self.__code_analyzer = CodeAnalyzer(self.__smt_solver, self.__smt_translator, self.__arch)
Example #14
0
class SmtTranslatorTests(unittest.TestCase):
    """Check the exact SMT-LIB text SmtTranslator emits for each REIL instruction."""

    def setUp(self):
        self._address_size = 32
        self._parser = ReilParser()
        self._solver = SmtSolver()
        self._translator = SmtTranslator(self._solver, self._address_size)

        self._arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        self._translator.set_arch_alias_mapper(self._arch_info.alias_mapper)
        self._translator.set_arch_registers_size(
            self._arch_info.registers_size)

    # Helpers
    def _translate(self, reil):
        """Parse a single REIL instruction string and return its translation."""
        instr = self._parser.parse([reil])[0]
        return self._translator.translate(instr)

    def _assert_translation(self, reil, expected):
        """Assert ``reil`` translates to exactly one formula whose text is ``expected``."""
        form = self._translate(reil)

        self.assertEqual(len(form), 1)
        self.assertEqual(form[0].value, expected)

    def _assert_unsupported(self, reil, message):
        """Assert translating ``reil`` raises an exception whose text contains ``message``."""
        # Parse outside the assertRaises block so parser errors are not masked.
        instr = self._parser.parse([reil])[0]

        with self.assertRaises(Exception) as context:
            self._translator.translate(instr)

        # ``message in exception`` only works on Python 2 (BaseException was
        # indexable there) and raises TypeError on Python 3; compare the text.
        self.assertIn(message, str(context.exception))

    # Arithmetic Instructions
    def test_translate_add_1(self):
        # Same size operands.
        self._assert_translation("add [BYTE t0, BYTE t1, BYTE t2]",
                                 "(= t2_1 (bvadd t0_0 t1_0))")

    def test_translate_add_2(self):
        # Destination operand larger than source operands.
        self._assert_translation(
            "add [BYTE t0, BYTE t1, WORD t2]",
            "(= t2_1 (bvadd ((_ zero_extend 8) t0_0) ((_ zero_extend 8) t1_0)))"
        )

    def test_translate_add_3(self):
        # Destination operand smaller than source operands.
        self._assert_translation("add [WORD t0, WORD t1, BYTE t2]",
                                 "(= t2_1 ((_ extract 7 0) (bvadd t0_0 t1_0)))")

    def test_translate_add_4(self):
        # Mixed source operands.
        self._assert_translation(
            "add [BYTE t0, BYTE 0x12, WORD t2]",
            "(= t2_1 (bvadd ((_ zero_extend 8) t0_0) ((_ zero_extend 8) #x12)))"
        )

    def test_translate_sub(self):
        self._assert_translation("sub [BYTE t0, BYTE t1, BYTE t2]",
                                 "(= t2_1 (bvsub t0_0 t1_0))")

    def test_translate_mul(self):
        self._assert_translation("mul [BYTE t0, BYTE t1, BYTE t2]",
                                 "(= t2_1 (bvmul t0_0 t1_0))")

    def test_translate_div(self):
        self._assert_translation("div [BYTE t0, BYTE t1, BYTE t2]",
                                 "(= t2_1 (bvudiv t0_0 t1_0))")

    def test_translate_mod(self):
        self._assert_translation("mod [BYTE t0, BYTE t1, BYTE t2]",
                                 "(= t2_1 (bvurem t0_0 t1_0))")

    def test_translate_bsh(self):
        self._assert_translation(
            "bsh [DWORD t0, DWORD t1, DWORD t2]",
            "(= t2_1 (ite (bvsge t1_0 #x00000000) (bvshl t0_0 t1_0) (bvlshr t0_0 (bvneg t1_0))))"
        )

    # Bitwise Instructions
    def test_translate_and(self):
        self._assert_translation("and [BYTE t0, BYTE t1, BYTE t2]",
                                 "(= t2_1 (bvand t0_0 t1_0))")

    def test_translate_or(self):
        self._assert_translation("or [BYTE t0, BYTE t1, BYTE t2]",
                                 "(= t2_1 (bvor t0_0 t1_0))")

    def test_translate_xor(self):
        self._assert_translation("xor [BYTE t0, BYTE t1, BYTE t2]",
                                 "(= t2_1 (bvxor t0_0 t1_0))")

    # Data Transfer Instructions
    def test_translate_ldm(self):
        self._assert_translation("ldm [DWORD t0, empty, BYTE t2]",
                                 "(= (select MEM_0 (bvadd t0_0 #x00000000)) t2_1)")

    def test_translate_stm(self):
        self._assert_translation(
            "stm [BYTE t0, empty, DWORD t2]",
            "(= MEM_1 (store MEM_0 (bvadd t2_0 #x00000000) t0_0))")

    def test_translate_str(self):
        self._assert_translation("str [BYTE t0, empty, BYTE t2]",
                                 "(= t2_1 t0_0)")

    # Conditional Instructions
    def test_translate_bisz(self):
        self._assert_translation(
            "bisz [DWORD t0, empty, DWORD t2]",
            "(= t2_1 (ite (= t0_0 #x00000000) #x00000001 #x00000000))")

    def test_translate_jcc(self):
        self._assert_translation("jcc [BIT t0, empty, DWORD t2]",
                                 "(not (= t0_0 #b0))")

    # Other Instructions
    def test_translate_undef(self):
        self._assert_unsupported("undef [empty, empty, DWORD t2]",
                                 "Unsupported instruction : UNDEF")

    def test_translate_unkn(self):
        self._assert_unsupported("unkn [empty, empty, empty]",
                                 "Unsupported instruction : UNKN")

    def test_translate_nop(self):
        form = self._translate("nop [empty, empty, empty]")

        self.assertEqual(len(form), 0)

    # Extensions
    def test_translate_sext(self):
        self._assert_translation("sext [BYTE t0, empty, WORD t2]",
                                 "(= t2_1 ((_ sign_extend 8) t0_0))")

    def test_translate_sdiv(self):
        self._assert_translation("sdiv [BYTE t0, BYTE t1, BYTE t2]",
                                 "(= t2_1 (bvsdiv t0_0 t1_0))")

    def test_translate_smod(self):
        self._assert_translation("smod [BYTE t0, BYTE t1, BYTE t2]",
                                 "(= t2_1 (bvsmod t0_0 t1_0))")
Example #15
0
class GadgetTools():
    """Find, classify and chain x86 ROP gadgets from an ELF binary with BARF.

    Wires together a REIL emulator, a gadget classifier, an SMT-based code
    analyzer and the chunk builders (``RegSet``, ``CCFlag``,
    ``ArithmeticStore``, ``MemoryStore``) for the binary's architecture.
    """

    def __init__(self, binary):
        """Build the analysis pipeline for ``binary``.

        Args:
            binary: stream handed to ``elffile.ELFFile``.

        Raises:
            ValueError: if the ELF class is neither 32 nor 64 bits.
        """
        self.elf = elffile.ELFFile(binary)

        if self.elf.elfclass == 32:
            self.arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)
        elif self.elf.elfclass == 64:
            self.arch_info = X86ArchitectureInformation(ARCH_X86_MODE_64)
        else:
            # Fail early with a clear message instead of an AttributeError on
            # the missing ``arch_info`` a few lines below.
            raise ValueError("unsupported ELF class: %r" % (self.elf.elfclass,))

        self.emulator = ReilEmulator(self.arch_info.address_size)
        self.emulator.set_arch_registers(self.arch_info.registers_gp)
        self.emulator.set_arch_registers_size(self.arch_info.register_size)
        self.emulator.set_reg_access_mapper(self.arch_info.register_access_mapper())

        self.classifier = GadgetClassifier(self.emulator, self.arch_info)

        self.smt_solver = SmtSolver()
        self.smt_translator = SmtTranslator(self.smt_solver, self.arch_info.address_size)

        self.smt_translator.set_reg_access_mapper(self.arch_info.register_access_mapper())
        self.smt_translator.set_arch_registers_size(self.arch_info.register_size)

        self.code_analyzer = CodeAnalyzer(self.smt_solver, self.smt_translator)

        # address -> raw gadget (filled by find_gadgets)
        self.gadgets = {}
        # gadget type -> list of typed gadgets (filled by classify_gadgets)
        self.classified_gadgets = {}

        # Chunk builders; each one receives this GadgetTools instance.
        self.regset = RegSet(self)
        self.ccf = CCFlag(self)
        self.ams = ArithmeticStore(self)
        self.memstr = MemoryStore(self)

        self.reil_translator = X86Translator(architecture_mode=self.arch_info.architecture_mode,
                                             translation_mode=FULL_TRANSLATION)

    def find_gadgets(self, max_instr=10, max_bytes=15):
        """Collect ``ret``-terminated gadgets from executable PROGBITS sections.

        Gadgets not ending in ``Ret``, or whose ``ret`` immediate operand is
        greater than 0x10, are skipped. Results go into ``self.gadgets``,
        keyed by gadget address.

        Args:
            max_instr: maximum number of instructions per gadget.
            max_bytes: maximum gadget length in bytes.
        """
        logging.info('searching gadgets in binary..')
        for s in self.elf.iter_sections():
            # 0x4 is SHF_EXECINSTR: only scan executable code sections.
            if s.header.sh_type != 'SHT_PROGBITS' or not (s.header.sh_flags & 0x4):
                continue

            sz = s.header.sh_size
            base = s.header.sh_addr
            # Read the section once instead of on every byte access, and bind
            # ``data``/``base`` as defaults so the reader stays tied to THIS
            # section even if ``mem`` is used after the loop has moved on.
            data = s.data()
            mem = Memory(lambda x, y, data=data, base=base: data[x - base], None)
            gfinder = GadgetFinder(X86Disassembler(architecture_mode=self.arch_info.architecture_mode),
                                   mem,
                                   self.reil_translator)

            logging.info("searching gadgets in section " + s.name + "...")

            for g in gfinder.find(base, base + sz - 1, max_instr, max_bytes):
                ret = g.instrs[-1].asm_instr
                if not isinstance(ret, Ret):
                    continue
                if len(ret.operands) > 0 and ret.operands[0].immediate > 0x10:
                    continue

                self.gadgets[g.address] = g

            # Running total over all sections scanned so far.
            logging.info("found {0} gadgets".format(len(self.gadgets)))

    def classify_gadgets(self):
        """Classify every found gadget and group the results by gadget type."""
        # Force cf to 0 so a random initial carry value cannot invalidate results.
        self.classifier.set_reg_init({'cf': 0})

        for g in self.gadgets.values():
            for tg in self.classifier.classify(g):
                self.classified_gadgets.setdefault(tg.type, []).append(tg)

    def find_reg_set_gadgets(self):
        """Feed all raw gadgets to the register-set chunk builder."""
        # iter(...) keeps handing RegSet an iterator on both Python 2 and 3.
        self.regset.add(iter(self.gadgets.values()))

    def read_carrier_flag(self, g):
        """Emulate ``g`` from a random register context.

        Returns:
            True if ``'eflags'`` is present in the emulator's registers after
            execution.
        """
        self.emulator.reset()
        regs_init = utils.make_random_regs_context(self.arch_info)

        self.emulator.execute_lite(g.get_ir_instrs(), regs_init)

        return 'eflags' in self.emulator.registers

    def find_arithmetic_mem_set_gadgets(self):
        """Feed classified ArithmeticStore gadgets to their chunk builder."""
        self.ams.add(self.classified_gadgets[GadgetType.ArithmeticStore])

    def get_stack_slide_chunk(self, slide):
        """Return a chunk that slides the stack by ``slide`` bytes."""
        # TODO: do not rely only on pop gadgets; chain more slides if we want a longer slide.
        return self.regset.get_slide_stack_chunk(slide)

    def get_ret_func_chunk(self, args, address):
        """Return a chainable chunk that returns to ``address`` with ``args`` set up
        as if a function had been called with them.

        On 64-bit targets the arguments are loaded into the System V argument
        registers; on 32-bit targets they are placed on the stack followed by
        a stack slide.

        Args:
            args (list): values to set up as function arguments.
            address (int): address to return to.

        Returns:
            The combined chunk, or None if the architecture size is neither
            32 nor 64.

        Raises:
            BaseException: on 64-bit targets when more than six args are given.
        """
        if self.arch_info.architecture_size == 64:
            if len(args) > 6:
                raise BaseException("chunk for calling a function whit more of six args isn't implemented")

            args_regs = ['rdi', 'rsi', 'rdx', 'rcx', 'r8', 'r9']
            regs_values = dict(zip(args_regs, args))
            regs_c = self.regset.get_chunk(regs_values)
            ret_c = PayloadChunk("", self.arch_info, address)
            return PayloadChunk.get_general_chunk([regs_c, ret_c])

        if self.arch_info.architecture_size == 32:
            slide_c = self.regset.get_slide_stack_chunk(len(args) * 4)
            logging.debug("slide chunk: %s", slide_c)
            ret_c = RetToAddress32(args, address, self.arch_info)
            logging.debug("ret chunk: %s", ret_c)
            return PayloadChunk.get_general_chunk([ret_c, slide_c])

        return None

    def get_mem_set_libc_read_chunk(self, location, fd, size, read_address):
        """Return a chunk that calls libc ``read`` to fill memory at ``location``.

        Only 32-bit targets are implemented; 64-bit prints a notice and
        returns None.

        Raises:
            NotImplementedError: for architecture sizes other than 32 or 64
                (previously this crashed with an unbound local).
        """
        if self.arch_info.architecture_size == 64:
            print("TO IMPLEMENT")
            return None

        if self.arch_info.architecture_size == 32:
            # Slide over the three stack-passed arguments (fd, buf, size).
            slide_chunk = self.regset.get_slide_stack_chunk(4 * 3)
            pl_chunk = MemSetLibcRead32(location,
                                        fd,
                                        size,
                                        read_address,
                                        self.arch_info)
            return PayloadChunk.get_general_chunk([pl_chunk, slide_chunk])

        raise NotImplementedError(
            "unsupported architecture size: %r" % (self.arch_info.architecture_size,))

    def build_mem_add(self, location, offset, size, mem_pre=None):
        """Return a chunk that adds ``offset`` to the ``size``-bit value at ``location``."""
        return self.ams.get_memory_add_chunk(location, offset, size, mem_pre)

    def check_mem_side_effects_and_stack_end(self, g, regs_init, location, size):
        """Emulate ``g`` and report unintended memory writes and the final stack offset.

        The stack pointer in ``regs_init`` is overwritten (mutated in place)
        with a fixed base of 0x50 before execution.

        Args:
            g: gadget providing ``get_ir_instrs()``.
            regs_init (dict): initial register values.
            location (int): start address of the intended write target.
            size (int): size of the intended write target, in bits.

        Returns:
            tuple: (addresses written outside ``[location, location + size/8)``,
            final stack pointer minus base minus one address-size word).

        Raises:
            Whatever ``execute_lite`` raises. The former bare ``except: pass``
            swallowed the error and then crashed on the unbound results with
            an UnboundLocalError, hiding the real cause.
        """
        stack_reg = 'rsp' if self.arch_info.architecture_size == 64 else 'esp'

        stack_base = 0x50
        regs_init[stack_reg] = stack_base

        # TODO: execute_lite can fail (e.g. zero division in mv while finding ccf).
        cregs, mem_final = self.emulator.execute_lite(g.get_ir_instrs(), regs_init)

        # Bytes of the intended write target are expected to change; build the
        # set once instead of a fresh list per inspected address.
        target_addrs = set(range(location, location + size // 8))
        stack_end = cregs[stack_reg]
        stack_delta = abs(stack_base - stack_end)

        mem_side_effects = []

        for addr in mem_final.get_addresses():
            if addr in target_addrs:
                continue

            sp, vp = mem_final.try_read_prev(addr, 8)
            sn, vn = mem_final.try_read(addr, 8)

            # Quick fix: newly-seen bytes inside the region the stack pointer
            # moved over are treated as stack reads, not side effects. We should
            # distinguish reads from the stack and side-effect writes properly.
            if sn and not sp and stack_base - stack_delta <= addr <= stack_base + stack_delta:
                continue

            if (sp and sn and vp != vn) or (sn and not sp):
                mem_side_effects.append(addr)

        return mem_side_effects, stack_end - stack_base - self.arch_info.address_size // 8

    def find_memory_store(self):
        """Feed classified StoreMemory gadgets (if any) to the memory-store builder."""
        self.mem_set_gadgets = {}

        if GadgetType.StoreMemory not in self.classified_gadgets:
            return

        self.memstr.add(self.classified_gadgets[GadgetType.StoreMemory])

    def find_ccfs(self):
        """Feed all raw gadgets to the CCFlag chunk builder."""
        self.ccf.add(self.gadgets.values())
class CodeAnalyzerTests(unittest.TestCase):
    """Check CodeAnalyzer path satisfiability on a small x86-32 function."""

    def setUp(self):
        arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)
        self._arch_info = arch_info
        self._operand_size = arch_info.operand_size
        self._memory = MemoryMock()

        # SMT side: solver plus a translator configured for x86-32.
        self._smt_solver = SmtSolver()
        self._smt_translator = SmtTranslator(self._smt_solver, self._operand_size)
        self._smt_translator.set_arch_alias_mapper(arch_info.alias_mapper)
        self._smt_translator.set_arch_registers_size(arch_info.registers_size)

        # Disassembly / REIL lifting side, used to build basic blocks.
        self._disasm = X86Disassembler()
        self._ir_translator = X86Translator()
        self._bb_builder = BasicBlockBuilder(self._disasm, self._memory, self._ir_translator, arch_info)

    def test_check_path_satisfiability(self):
        if VERBOSE:
            print("[+] Test: test_check_path_satisfiability")

        # binary : stack1
        bin_start_address, bin_end_address = 0x08048ec0, 0x8048f02

        binary = "".join([
            "\x55",                          # 0x08048ec0 : push   ebp
            "\x89\xe5",                      # 0x08048ec1 : mov    ebp,esp
            "\x83\xec\x60",                  # 0x08048ec3 : sub    esp,0x60
            "\x8d\x45\xfc",                  # 0x08048ec6 : lea    eax,[ebp-0x4]
            "\x89\x44\x24\x08",              # 0x08048ec9 : mov    DWORD PTR [esp+0x8],eax
            "\x8d\x45\xac",                  # 0x08048ecd : lea    eax,[ebp-0x54]
            "\x89\x44\x24\x04",              # 0x08048ed0 : mov    DWORD PTR [esp+0x4],eax
            "\xc7\x04\x24\xa8\x5a\x0c\x08",  # 0x08048ed4 : mov    DWORD PTR [esp],0x80c5aa8
            "\xe8\xa0\x0a\x00\x00",          # 0x08048edb : call   8049980 <_IO_printf>
            "\x8d\x45\xac",                  # 0x08048ee0 : lea    eax,[ebp-0x54]
            "\x89\x04\x24",                  # 0x08048ee3 : mov    DWORD PTR [esp],eax
            "\xe8\xc5\x0a\x00\x00",          # 0x08048ee6 : call   80499b0 <_IO_gets>
            "\x8b\x45\xfc",                  # 0x08048eeb : mov    eax,DWORD PTR [ebp-0x4]
            "\x3d\x44\x43\x42\x41",          # 0x08048eee : cmp    eax,0x41424344
            "\x75\x0c",                      # 0x08048ef3 : jne    8048f01 <main+0x41>
            "\xc7\x04\x24\xc0\x5a\x0c\x08",  # 0x08048ef5 : mov    DWORD PTR [esp],0x80c5ac0
            "\xe8\x4f\x0c\x00\x00",          # 0x08048efc : call   8049b50 <_IO_puts>
            "\xc9",                          # 0x08048f01 : leave
            "\xc3",                          # 0x08048f02 : ret
        ])

        self._memory.set_base_address(bin_start_address)
        self._memory.set_content(binary)

        # Paths are enumerated from the function entry to the ``leave``.
        start, end = 0x08048ec0, 0x08048f01

        register_values = [
            ("eax", 0xffffd0ec),
            ("ecx", 0x00000001),
            ("edx", 0xffffd0e4),
            ("ebx", 0x00000000),
            ("esp", 0xffffd05c),
            ("ebp", 0x08049580),
            ("esi", 0x00000000),
            ("edi", 0x08049620),
            ("eip", 0x08048ec0),
        ]
        registers = {name: GenericRegister(name, 32, value) for name, value in register_values}

        flag_values = [
            ("af", 0x0),
            ("cf", 0x0),
            ("of", 0x0),
            ("pf", 0x1),
            ("sf", 0x0),
            ("zf", 0x1),
        ]
        flags = {name: GenericFlag(name, value) for name, value in flag_values}

        bb_graph = BasicBlockGraph(self._bb_builder.build(bin_start_address, bin_end_address))

        code_analyzer = CodeAnalyzer(self._smt_solver, self._smt_translator, self._arch_info)
        code_analyzer.set_context(GenericContext(registers, flags, {}))

        for bb_path in bb_graph.all_simple_bb_paths(start, end):
            if VERBOSE:
                print("[+] Checking path satisfiability :")
                print("      From : %s" % hex(start))
                print("      To : %s" % hex(end))
                print("      Path : %s" % " -> ".join(hex(bb.address) for bb in bb_path))

            is_sat = code_analyzer.check_path_satisfiability(bb_path, start, verbose=False)

            if VERBOSE:
                print("[+] Satisfiability : %s" % str(is_sat))

            self.assertTrue(is_sat)

            if VERBOSE:
                if is_sat:
                    print(code_analyzer.get_context())
                print(":" * 80)
                print("")
class ArmGadgetClassifierTests(unittest.TestCase):
    """Classification and verification of small ARM gadgets.

    Uses ``assertEqual`` throughout: the ``assertEquals`` alias is deprecated
    and was removed in Python 3.12.
    """

    def setUp(self):

        self._arch_info = ArmArchitectureInformation(ARCH_ARM_MODE_32)
        self._smt_solver = SmtSolver()
        self._smt_translator = SmtTranslator(self._smt_solver, self._arch_info.address_size)
        self._ir_emulator = ReilEmulator(self._arch_info.address_size)

        self._ir_emulator.set_arch_registers(self._arch_info.registers_gp_all)
        self._ir_emulator.set_arch_registers_size(self._arch_info.registers_size)
        self._ir_emulator.set_reg_access_mapper(self._arch_info.alias_mapper)

        self._smt_translator.set_reg_access_mapper(self._arch_info.alias_mapper)
        self._smt_translator.set_arch_registers_size(self._arch_info.registers_size)

        self._code_analyzer = CodeAnalyzer(self._smt_solver, self._smt_translator)

        self._g_classifier = GadgetClassifier(self._ir_emulator, self._arch_info)
        self._g_verifier = GadgetVerifier(self._code_analyzer, self._arch_info)

    def _find_and_classify_gadgets(self, binary):
        """Find gadget candidates in ``binary`` and classify the first one."""
        g_finder = GadgetFinder(ArmDisassembler(), binary, ArmTranslator(translation_mode=LITE_TRANSLATION), ARCH_ARM, ARCH_ARM_MODE_32)

        g_candidates = g_finder.find(0x00000000, len(binary), instrs_depth=4)
        g_classified = self._g_classifier.classify(g_candidates[0])

        # Debug:
        # self._print_candidates(g_candidates)
        # self._print_classified(g_classified)

        return g_candidates, g_classified

    def test_move_register_1(self):
        # testing : dst_reg <- src_reg
        binary  = "\x04\x00\xa0\xe1"                     # 0x00 : (4)  mov    r0, r4
        binary += "\x31\xff\x2f\xe1"                     # 0x04 : (4)  blx    r1

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 2)

        self.assertEqual(g_classified[0].type, GadgetType.MoveRegister)
        self.assertEqual(g_classified[0].sources, [ReilRegisterOperand("r4", 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r0", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 1)
        self.assertTrue(ReilRegisterOperand("r14", 32) in g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_move_register_2(self):
        # testing : dst_reg <- src_reg
        binary  = "\x00\x00\x84\xe2"                     # 0x00 : (4)  add    r0, r4, #0
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.MoveRegister)
        self.assertEqual(g_classified[0].sources, [ReilRegisterOperand("r4", 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r0", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    # TODO: test_move_register_n: mul r0, r4, #1

    def test_load_constant_1(self):
        # testing : dst_reg <- constant
        binary  = "\x0a\x20\xa0\xe3"                     # 0x00 : (4)  mov    r2, #10
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEqual(g_classified[0].sources, [ReilImmediateOperand(10, 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r2", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)
        self.assertFalse(ReilRegisterOperand("r2", 32) in g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_2(self):
        # testing : dst_reg <- constant
        binary  = "\x02\x20\x42\xe0"                     # 0x00 : (4)  sub    r2, r2, r2
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEqual(g_classified[0].sources, [ReilImmediateOperand(0, 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r2", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)
        self.assertFalse(ReilRegisterOperand("r2", 32) in g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_3(self):
        # testing : dst_reg <- constant
        binary  = "\x02\x20\x22\xe0"                     # 0x00 : (4)  eor    r2, r2, r2
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEqual(g_classified[0].sources, [ReilImmediateOperand(0, 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r2", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)
        self.assertFalse(ReilRegisterOperand("r2", 32) in g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_4(self):
        # testing : dst_reg <- constant
        binary  = "\x00\x20\x02\xe2"                     # 0x00 : (4)  and    r2, r2, #0
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEqual(g_classified[0].sources, [ReilImmediateOperand(0, 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r2", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)
        self.assertFalse(ReilRegisterOperand("r2", 32) in g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_5(self):
        # testing : dst_reg <- constant
        binary  = "\x00\x20\x02\xe2"                     # 0x00 : (4)  and    r2, r2, #0
        binary += "\x21\x20\x82\xe3"                     # 0x04 : (4)  orr    r2, r2, #33
        binary += "\x1e\xff\x2f\xe1"                     # 0x08 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEqual(g_classified[0].sources, [ReilImmediateOperand(33, 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r2", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)
        self.assertFalse(ReilRegisterOperand("r2", 32) in g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_arithmetic_add_1(self):
        # testing : dst_reg <- src1_reg + src2_reg
        binary  = "\x08\x00\x84\xe0"                     # 0x00 : (4)  add    r0, r4, r8
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.Arithmetic)
        self.assertEqual(g_classified[0].sources, [ReilRegisterOperand("r4", 32), ReilRegisterOperand("r8", 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r0", 32)])
        self.assertEqual(g_classified[0].operation, "+")

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_arithmetic_sub_1(self):
        # testing : dst_reg <- src1_reg - src2_reg
        binary  = "\x08\x00\x44\xe0"                     # 0x00 : (4)  sub    r0, r4, r8
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.Arithmetic)
        self.assertEqual(g_classified[0].sources, [ReilRegisterOperand("r4", 32), ReilRegisterOperand("r8", 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r0", 32)])
        self.assertEqual(g_classified[0].operation, "-")

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_memory_1(self):
        # testing : dst_reg <- m[src_reg]
        binary  = "\x00\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4]
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadMemory)
        self.assertEqual(g_classified[0].sources, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r3", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_memory_2(self):
        # testing : dst_reg <- m[src_reg + offset]
        binary  = "\x33\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4, #0x33]
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadMemory)
        self.assertEqual(g_classified[0].sources, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x33, 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r3", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    # TODO: ARM's ldr rd, [rn, r2] is not a valid classification right now

    def test_store_memory_1(self):
        # testing : m[dst_reg] <- src_reg
        binary  = "\x00\x30\x84\xe5"                     # 0x00 : (4)  str    r3, [r4]
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.StoreMemory)
        self.assertEqual(g_classified[0].sources, [ReilRegisterOperand("r3", 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_store_memory_2(self):
        # testing : m[dst_reg + offset] <- src_reg
        binary  = "\x33\x30\x84\xe5"                     # 0x00 : (4)  str    r3, [r4, #0x33]
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.StoreMemory)
        self.assertEqual(g_classified[0].sources, [ReilRegisterOperand("r3", 32)])
        self.assertEqual(g_classified[0].destination, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x33, 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_arithmetic_load_add_1(self):
        # testing : dst_reg <- dst_reg + mem[src_reg]
        binary  = "\x00\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4]
        binary += "\x03\x00\x80\xe0"                     # 0x04 : (4)  add    r0, r0, r3
        binary += "\x1e\xff\x2f\xe1"                     # 0x08 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 2)

        self.assertEqual(g_classified[1].type, GadgetType.ArithmeticLoad)
        self.assertEqual(g_classified[1].sources, [ReilRegisterOperand("r0", 32), ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32)])
        self.assertEqual(g_classified[1].destination, [ReilRegisterOperand("r0", 32)])
        self.assertEqual(g_classified[1].operation, "+")

        self.assertEqual(len(g_classified[1].modified_registers), 1)

        self.assertFalse(ReilRegisterOperand("r0", 32) in g_classified[1].modified_registers)
        self.assertTrue(ReilRegisterOperand("r3", 32) in g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def test_arithmetic_load_add_2(self):
        # testing : dst_reg <- dst_reg + mem[src_reg + offset]
        binary  = "\x22\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4, #0x22]
        binary += "\x03\x00\x80\xe0"                     # 0x04 : (4)  add    r0, r0, r3
        binary += "\x1e\xff\x2f\xe1"                     # 0x08 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 2)

        self.assertEqual(g_classified[1].type, GadgetType.ArithmeticLoad)
        self.assertEqual(g_classified[1].sources, [ReilRegisterOperand("r0", 32), ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x22, 32)])
        self.assertEqual(g_classified[1].destination, [ReilRegisterOperand("r0", 32)])
        self.assertEqual(g_classified[1].operation, "+")

        self.assertEqual(len(g_classified[1].modified_registers), 1)

        self.assertFalse(ReilRegisterOperand("r0", 32) in g_classified[1].modified_registers)
        self.assertTrue(ReilRegisterOperand("r3", 32) in g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def test_arithmetic_store_add_1(self):
        # testing : m[dst_reg] <- m[dst_reg] + src_reg
        binary  = "\x00\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4]
        binary += "\x03\x30\x80\xe0"                     # 0x04 : (4)  add    r3, r0, r3
        binary += "\x00\x30\x84\xe5"                     # 0x08 : (4)  str    r3, [r4]
        binary += "\x1e\xff\x2f\xe1"                     # 0x0c : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 2)

        self.assertEqual(g_classified[1].type, GadgetType.ArithmeticStore)
        self.assertEqual(g_classified[1].sources, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32), ReilRegisterOperand("r0", 32)])
        self.assertEqual(g_classified[1].destination, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32)])
        self.assertEqual(g_classified[1].operation, "+")

        self.assertEqual(len(g_classified[1].modified_registers), 1)

        self.assertFalse(ReilRegisterOperand("r4", 32) in g_classified[1].modified_registers)
        self.assertTrue(ReilRegisterOperand("r3", 32) in g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def test_arithmetic_store_add_2(self):
        # testing : m[dst_reg + offset] <- m[dst_reg + offset] + src_reg
        binary  = "\x22\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4, #0x22]
        binary += "\x03\x30\x80\xe0"                     # 0x04 : (4)  add    r3, r0, r3
        binary += "\x22\x30\x84\xe5"                     # 0x08 : (4)  str    r3, [r4, #0x22]
        binary += "\x1e\xff\x2f\xe1"                     # 0x0c : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 2)

        self.assertEqual(g_classified[1].type, GadgetType.ArithmeticStore)
        self.assertEqual(g_classified[1].sources, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x22, 32), ReilRegisterOperand("r0", 32)])
        self.assertEqual(g_classified[1].destination, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x22, 32)])
        self.assertEqual(g_classified[1].operation, "+")

        self.assertEqual(len(g_classified[1].modified_registers), 1)

        self.assertFalse(ReilRegisterOperand("r4", 32) in g_classified[1].modified_registers)
        self.assertTrue(ReilRegisterOperand("r3", 32) in g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def _print_candidates(self, candidates):
        """Debug helper: print every gadget candidate."""
        print("Candidates :")

        for gadget in candidates:
            print(gadget)
            print("-" * 10)

    def _print_classified(self, classified):
        """Debug helper: print every classified gadget and its type."""
        print("Classified :")

        for gadget in classified:
            print(gadget)
            print(gadget.type)
            print("-" * 10)
def barf_classify(gadget_map, printout=True):
    """Translate gadgets to BARF/REIL, classify them and optionally print a report.

    A gadget is classified only if every one of its instructions translates
    to REIL. Previously the flag reflected only the last instruction, so
    gadgets containing untranslatable instructions were still classified.

    Args:
        gadget_map: mapping whose values are gadgets exposing an
            ``instructions`` mapping; each instruction has ``mnemonic``,
            ``op_str`` and ``address`` attributes.
        printout: when true, print raw, verified, unverified and
            unclassified gadgets to stdout.

    Returns:
        dict mapping gadget address to typed gadget (if several typed gadgets
        share an address, the last one wins).
    """
    arch_mode = ARCH_X86_MODE_32
    arch_info = X86ArchitectureInformation(arch_mode)
    translator = X86Translator(arch_mode)
    instruction_parser = X86Parser(arch_mode)
    ir_emulator = ReilEmulator(arch_info)
    classifier = GadgetClassifier(ir_emulator, arch_info)
    raw_gadgets = {}
    typed_gadgets = []
    for gadget in gadget_map.values():

        # Translation cycle: from my emulator to BARF representation
        all_translated = True
        barf_instr_list = []
        for instr in gadget.instructions.values():
            # Parse a ROPInstruction into the BARF representation of an x86 instruction
            barf_instr = instruction_parser.parse("{} {}".format(
                instr.mnemonic, instr.op_str))
            barf_instr.address = instr.address
            try:
                # Translate an x86 instruction into a list of REIL instructions
                barf_instr.ir_instrs = translator.translate(barf_instr)
            except TranslationError:
                all_translated = False
            barf_instr_list.append(barf_instr)

        # An empty gadget is never classifiable (matches previous behavior).
        classifiable = all_translated and bool(barf_instr_list)

        # Classification of the gadgets
        barf_g = RawGadget(barf_instr_list)
        raw_gadgets[barf_g.address] = barf_g
        if classifiable:
            typed_gadgets.extend(classifier.classify(barf_g))

    if printout:
        print_gadgets_raw(list(raw_gadgets.values()), sys.stdout, 'addr', True,
                          'Raw Gadgets', False)
        verified = []
        unverified = []
        solver = Z3Solver()
        # Separate name so the REIL ``translator`` above is not shadowed.
        smt_translator = SmtTranslator(solver, arch_info.address_size)
        code_analyzer = CodeAnalyzer(solver, smt_translator, arch_info)
        verifier = GadgetVerifier(code_analyzer, arch_info)
        for tg in typed_gadgets:
            if verifier.verify(tg):
                verified.append(tg)
            else:
                unverified.append(tg)
        print_gadgets_typed(verified, sys.stdout, arch_info.address_size,
                            'Verified classification')
        print_gadgets_typed(unverified, sys.stdout, arch_info.address_size,
                            'Unverified classification')
        # Whatever remains after removing classified addresses was not classified.
        for tg in typed_gadgets:
            raw_gadgets.pop(tg.address, None)
        print_gadgets_raw(list(raw_gadgets.values()), sys.stdout, 'addr',
                          False, 'Not classified', False)

    return {tg.address: tg for tg in typed_gadgets}
Example #19
0
class X86TranslationTests(unittest.TestCase):

    def setUp(self):
        self.trans_mode = FULL_TRANSLATION

        self.arch_mode = ARCH_X86_MODE_64

        self.arch_info = X86ArchitectureInformation(self.arch_mode)

        self.x86_parser = X86Parser(self.arch_mode)
        self.x86_translator = X86Translator(self.arch_mode, self.trans_mode)
        self.smt_solver = SmtSolver()
        self.smt_translator = SmtTranslator(self.smt_solver, self.arch_info.address_size)
        self.reil_emulator = ReilEmulator(self.arch_info.address_size)

        self.reil_emulator.set_arch_registers(self.arch_info.registers_gp)
        self.reil_emulator.set_arch_registers_size(self.arch_info.register_size)
        self.reil_emulator.set_reg_access_mapper(self.arch_info.register_access_mapper())

        self.smt_translator.set_reg_access_mapper(self.arch_info.register_access_mapper())
        self.smt_translator.set_arch_registers_size(self.arch_info.register_size)

    def test_lea(self):
        asm = ["lea eax, [ebx + 0x100]"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_cld(self):
        asm = ["cld"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_clc(self):
        asm = ["clc"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_nop(self):
        asm = ["nop"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_test(self):
        asm = ["test eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_not(self):
        asm = ["not eax"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_xor(self):
        asm = ["xor eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_or(self):
        asm = ["or eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_and(self):
        asm = ["and eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_cmp(self):
        asm = ["cmp eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_neg(self):
        asm = ["neg eax"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_dec(self):
        asm = ["dec eax"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_inc(self):
        asm = ["inc eax"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_div(self):
        asm = ["div ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = {
            'rax'    : 0x10,
            'rbx'    : 0x2,
            'rdx'    : 0x0,
            'rflags' : 0x202,
        }

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_imul(self):
        asm = ["imul eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_mul(self):
        asm = ["mul ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_sbb(self):
        asm = ["sbb eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_sub(self):
        asm = ["sub eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_adc(self):
        asm = ["adc eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_add(self):
        asm = ["add eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_xchg(self):
        asm = ["xchg eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_movzx(self):
        asm = ["movzx eax, bx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_mov(self):
        asm = ["mov eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_shr(self):
        asm = ["shr eax, 3"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_shl(self):
        asm = ["shl eax, 3"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_sal(self):
        asm = ["sal eax, 3"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_sar(self):
        asm = ["sar eax, 3"]

        x86_instrs = map(self.x86_parser.parse, asm)
        x86_instrs[0].address = 0xdeadbeef

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def test_stc(self):
        asm = ["stc"]

        x86_instrs = map(self.x86_parser.parse, asm)
        x86_instrs[0].address = 0xdeadbeef

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)

        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def __init_context(self):
        return {
            'rax'    : 0xa,
            'rbx'    : 0x2,
            'rcx'    : 0xb,
            'rdx'    : 0xc,
            'rdi'    : 0xd,
            'rsi'    : 0xe,
            'rflags' : 0x202,
        }

    def __compare_contexts(self, context_init, x86_context, reil_context):
        match = True

        fmt = "%s (x86)  : %s (%s)"

        mask = 2**64-1

        for reg in sorted(context_init.keys()):
            if ((2**64-1) & x86_context[reg]) != ((2**64-1) & reil_context[reg]):
                x86_value = x86_context[reg] & mask
                reil_value = reil_context[reg] & mask

                if reg in ['rflags', 'eflags']:
                    x86_flags_str = self.__print_flags(x86_context[reg])
                    reil_flags_str = self.__print_flags(reil_context[reg])

                    print ("%s (x86)  : %s (%s)" % (reg, hex(x86_value), x86_flags_str))
                    print ("%s (reil) : %s (%s)" % (reg, hex(reil_value), reil_flags_str))
                else:
                    print ("%s (x86)  : %s " % (reg, hex(x86_value)))
                    print ("%s (reil) : %s " % (reg, hex(reil_value)))

                match = False
                break

        if not match:
            self.__print_contexts(context_init, x86_context, reil_context)

        return match

    def __print_contexts(self, context_init, x86_context, reil_context):
        header_fmt = "{0:^8s} : {1:>16s} ?= {2:<16s}"
        header = header_fmt.format("Register", "x86", "REIL")
        ruler  = "-" * len(header)

        print(header)
        print(ruler)

        fmt = "{0:>8s} : {1:016x} {eq} {2:016x} ({1:>5d} {eq} {2:<5d}) {marker}"
        mask = 2**64-1

        for reg in sorted(context_init.keys()):
            if (x86_context[reg] & mask) != (reil_context[reg] & mask):
                eq = "!="
                marker = "<"
            else:
                eq = "=="
                marker = ""

            print fmt.format(
                reg,
                (2**64-1) & x86_context[reg],
                (2**64-1) & reil_context[reg],
                eq=eq,
                marker=marker
            )

    def __print_flags(self, flags_reg):
        # flags
        flags = {
             0 : "cf",  # bit 0
             2 : "pf",  # bit 2
             4 : "af",  # bit 4
             6 : "zf",  # bit 6
             7 : "sf",  # bit 7
            11 : "of",  # bit 11
            10 : "df",  # bit 10
        }

        out = ""

        for bit, flag in flags.items():
            if flags_reg & 2**bit:
                out += flag.upper() + " "
            else:
                out += flag.lower() + " "

        return out[:-1]

    def __fix_reil_flags(self, reil_context, x86_context):
        reil_context_out = dict(reil_context)

        flags_reg = 'eflags' if 'eflags' in reil_context_out else 'rflags'

        # Remove this when AF and PF are implemented.
        reil_context_out[flags_reg] |= (x86_context[flags_reg] & 2**4) # AF
        reil_context_out[flags_reg] |= (x86_context[flags_reg] & 2**2) # PF

        return reil_context_out
Example #20
0
class X86TranslationTests(unittest.TestCase):

    def setUp(self):
        self.trans_mode = FULL_TRANSLATION

        self.arch_mode = ARCH_X86_MODE_64

        self.arch_info = X86ArchitectureInformation(self.arch_mode)

        self.x86_parser = X86Parser(self.arch_mode)
        self.x86_translator = X86Translator(self.arch_mode, self.trans_mode)
        self.smt_solver = SmtSolver()
        self.smt_translator = SmtTranslator(self.smt_solver, self.arch_info.address_size)
        self.reil_emulator = ReilEmulator(self.arch_info.address_size)

        self.reil_emulator.set_arch_registers(self.arch_info.registers_gp)
        self.reil_emulator.set_arch_registers_size(self.arch_info.register_size)
        self.reil_emulator.set_reg_access_mapper(self.arch_info.register_access_mapper())

        self.smt_translator.set_reg_access_mapper(self.arch_info.register_access_mapper())
        self.smt_translator.set_arch_registers_size(self.arch_info.register_size)

    def test_lea(self):
        asm = ["lea eax, [ebx + 0x100]"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_cld(self):
        asm = ["cld"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_clc(self):
        asm = ["clc"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_nop(self):
        asm = ["nop"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_test(self):
        asm = ["test eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_not(self):
        asm = ["not eax"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_xor(self):
        asm = ["xor eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_or(self):
        asm = ["or eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_and(self):
        asm = ["and eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_cmp(self):
        asm = ["cmp eax, ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_neg(self):
        asm = ["neg eax"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_dec(self):
        asm = ["dec eax"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_inc(self):
        asm = ["inc eax"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = self.__init_context()

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_div(self):
        asm = ["div ebx"]

        x86_instrs = map(self.x86_parser.parse, asm)

        reil_instrs = map(self.x86_translator.translate, x86_instrs)

        context_init = {
            'rax'    : 0x10,
            'rbx'    : 0x2,
            'rdx'    : 0x0,
            'rflags' : 0x202,
        }

        x86_rv, x86_context_out  = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, reil_memory_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=self.__update_flags_from_rflags(context_init))

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def test_imul(self):
        """Compare native and REIL-emulated execution of ``imul eax, ebx``."""
        asm = ["imul eax, ebx"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def test_mul(self):
        """Compare native and REIL-emulated execution of ``mul ebx``."""
        asm = ["mul ebx"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def test_sbb(self):
        """Compare native and REIL-emulated execution of ``sbb eax, ebx``."""
        asm = ["sbb eax, ebx"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def test_sub(self):
        """Compare native and REIL-emulated execution of ``sub eax, ebx``."""
        asm = ["sub eax, ebx"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def test_adc(self):
        """Compare native and REIL-emulated execution of ``adc eax, ebx``."""
        asm = ["adc eax, ebx"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def test_add(self):
        """Compare native and REIL-emulated execution of ``add eax, ebx``."""
        asm = ["add eax, ebx"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def test_xchg(self):
        """Compare native and REIL-emulated execution of ``xchg eax, ebx``."""
        asm = ["xchg eax, ebx"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def test_movzx(self):
        """Compare native and REIL-emulated execution of ``movzx eax, bx``."""
        asm = ["movzx eax, bx"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def test_mov(self):
        """Compare native and REIL-emulated execution of ``mov eax, ebx``."""
        asm = ["mov eax, ebx"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def test_shr(self):
        """Compare native and REIL-emulated execution of ``shr eax, 3``."""
        asm = ["shr eax, 3"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def test_shl(self):
        """Compare native and REIL-emulated execution of ``shl eax, 3``."""
        asm = ["shl eax, 3"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def test_sal(self):
        """Compare native and REIL-emulated execution of ``sal eax, 3``."""
        asm = ["sal eax, 3"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def test_sar(self):
        """Compare native and REIL-emulated execution of ``sar eax, 3``."""
        asm = ["sar eax, 3"]

        parsed_instrs = [self.x86_parser.parse(line) for line in asm]
        # Unlike the sibling tests, this one pins the instruction address to
        # 0xdeadbeef before translating it to REIL.
        parsed_instrs[0].address = 0xdeadbeef

        reil_instrs = [self.x86_translator.translate(instr) for instr in parsed_instrs]

        ctx_init = self.__init_context()

        _, native_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_in = self.__update_flags_from_rflags(ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, context=reil_ctx_in)

        self.assertTrue(self.__compare_contexts(ctx_init, native_ctx_out, reil_ctx_out))

    def __init_context(self):
        """Return a fresh default register context for the instruction tests.

        A new dict is built on every call so tests cannot leak state into
        each other. ``rflags`` starts at 0x202: bit 9 (IF) and the reserved
        bit 1 set, every arithmetic flag clear.
        """
        return dict(
            rax=0xa,
            rbx=0x2,
            rcx=0xb,
            rdx=0xc,
            rdi=0xd,
            rsi=0xe,
            rflags=0x202,
        )

    def __compare_contexts(self, context_init, x86_context, reil_context):
        """Return True when the native and REIL contexts agree.

        Only the registers listed in *context_init* are compared (in sorted
        order), each masked to 64 bits. On the first mismatch both values are
        printed, each labelled with the run it came from, and the comparison
        stops.

        :param context_init: initial context; its keys select the registers.
        :param x86_context: register values after native execution.
        :param reil_context: register values after REIL emulation.
        :returns: True if every selected register matches, False otherwise.
        """
        mask = 2**64 - 1

        for reg in sorted(context_init.keys()):
            x86_value = mask & x86_context[reg]
            reil_value = mask & reil_context[reg]

            if x86_value != reil_value:
                # Label each line; previously both lines were formatted
                # identically, so it was unclear which value was which.
                print("%s : %s (x86)" % (reg, hex(x86_value)))
                print("%s : %s (reil)" % (reg, hex(reil_value)))
                return False

        return True

    def __print_contexts(self, context_init, x86_context, reil_context):
        """Print a native-vs-REIL register table for debugging.

        One row per register of *context_init*, sorted by name. Values are
        masked to 64 bits and shown in hex and decimal; rows whose values
        differ use ``!=`` and are marked with a trailing ``<``.
        """
        mask = 2**64 - 1
        header = "{0:^8s} : {1:>16s} ?= {2:<16s}".format("Register", "x86", "REIL")
        row_fmt = "{0:>8s} : {1:016x} {eq} {2:016x} ({1:>5d} {eq} {2:<5d}) {marker}"

        print(header)
        print("-" * len(header))

        for reg in sorted(context_init.keys()):
            native_value = mask & x86_context[reg]
            emulated_value = mask & reil_context[reg]
            differs = native_value != emulated_value

            print(row_fmt.format(
                reg,
                native_value,
                emulated_value,
                eq="!=" if differs else "==",
                marker="<" if differs else "",
            ))

    def __update_rflags(self, reil_context_out, x86_context_out):
        """Project a REIL context onto the native registers and rebuild rflags.

        Returns a new dict holding only the entries of *reil_context_out*
        whose register also appears in *x86_context_out*. An ``rflags`` value
        is then assembled and stored in it:

        * OF (11), DF (10), SF (7), ZF (6) and CF (0) come from the REIL
          flag registers;
        * AF (bit 4) and PF (bit 2) are copied from the native rflags, not
          from REIL;
        * IF (bit 9) and reserved bit 1 are always set;
        * every other bit is 0.

        The result is truncated to 32 bits.
        """
        native_regs = x86_context_out.keys()
        reil_context = {
            reg: value
            for reg, value in reil_context_out.items()
            if reg in native_regs
        }

        # Fetch the inputs in the same order the bit expression consumes them.
        of_flag = reil_context_out['of']
        df_flag = reil_context_out['df']
        sf_flag = reil_context_out['sf']
        zf_flag = reil_context_out['zf']
        native_rflags = x86_context_out['rflags']
        cf_flag = reil_context_out['cf']

        rflags = 0x0
        rflags |= of_flag << 11           # OF
        rflags |= df_flag << 10           # DF
        rflags |= 0x1 << 9                # IF
        rflags |= sf_flag << 7            # SF
        rflags |= zf_flag << 6            # ZF
        rflags |= native_rflags & 0x10    # AF (native value)
        rflags |= native_rflags & 0x4     # PF (native value)
        rflags |= 0x1 << 1                # Reserved, always 1
        rflags |= cf_flag                 # CF

        reil_context['rflags'] = 0xffffffff & rflags

        return reil_context

    def __update_flags_from_rflags(self, reil_context):
        """Return a copy of *reil_context* with the REIL flag registers filled in.

        The flags register is ``eflags`` if present, otherwise ``rflags``.
        Its OF, DF, SF, ZF, AF, PF and CF bits are split into the registers
        ``of``, ``df``, ``sf``, ``zf``, ``af``, ``pf`` and ``cf``, each as 0
        or 1. The sibling rflags-rebuilding helper shifts these values back
        into place, so they must be single bits. (Previously they were stored
        as ``rflags & 2**bit``, e.g. 2048 for OF.) If neither flags register
        is present, an unchanged copy is returned. The input dict is never
        modified.
        """
        reil_context_out = dict(reil_context)

        # 'eflags' wins when both are present, as before.
        if 'eflags' in reil_context_out:
            flags_reg = 'eflags'
        elif 'rflags' in reil_context_out:
            flags_reg = 'rflags'
        else:
            return reil_context_out

        flags_value = reil_context_out[flags_reg]

        flag_bits = (
            ('of', 11),  # OF
            ('df', 10),  # DF
            ('sf', 7),   # SF
            ('zf', 6),   # ZF
            ('af', 4),   # AF
            ('pf', 2),   # PF
            ('cf', 0),   # CF
        )

        for flag, bit in flag_bits:
            reil_context_out[flag] = (flags_value >> bit) & 0x1

        return reil_context_out
Example #21
0
class SymExecResult(object):
    """Solver-backed view of one symbolic execution path.

    At construction time a fresh Z3-backed CodeAnalyzer is built for *arch*.
    The initial state (as "pre" expressions), the REIL trace *path* and the
    final state (as "post" expressions) are asserted into it, and the solver
    must report "sat". The query methods then return concrete values for
    registers and memory in the pre-trace state.
    """

    def __init__(self, arch, initial_state, path, final_state):
        self.__initial_state = initial_state
        self.__path = path
        self.__final_state = final_state

        self.__arch = arch

        self.__smt_solver = None
        self.__smt_translator = None

        self.__code_analyzer = None

        self.__initialize_analyzer()

        self.__setup_solver()

    def query_register(self, register):
        """Return a solver-chosen value of *register* before the trace ran."""
        # TODO: This method should return an iterator.
        smt_expr = self.__code_analyzer.get_register_expr(register, mode="pre")
        return self.__code_analyzer.get_expr_value(smt_expr)

    def query_memory(self, address, size):
        """Return a solver-chosen value of the *size*-wide memory at *address*
        before the trace ran."""
        # TODO: This method should return an iterator.
        smt_expr = self.__code_analyzer.get_memory_expr(address, size, mode="pre")
        return self.__code_analyzer.get_expr_value(smt_expr)

    # Auxiliary methods
    # ======================================================================== #
    def __initialize_analyzer(self):
        """Create the Z3 solver, the SMT translator and the code analyzer."""
        self.__smt_solver = Z3Solver()

        self.__smt_translator = SmtTranslator(self.__smt_solver, self.__arch.address_size)
        self.__smt_translator.set_arch_alias_mapper(self.__arch.alias_mapper)
        self.__smt_translator.set_arch_registers_size(self.__arch.registers_size)

        self.__code_analyzer = CodeAnalyzer(self.__smt_solver, self.__smt_translator, self.__arch)

    def __setup_solver(self):
        """Assert initial state, trace and final state, then check satisfiability."""
        self.__set_initial_state(self.__initial_state)
        self.__add_trace_to_solver(self.__path)
        self.__set_final_state(self.__final_state)

        # check() used to be called inside the assert, so `python -O` skipped
        # the solver check altogether. Run it unconditionally and only
        # assert on its result.
        status = self.__code_analyzer.check()
        assert status == "sat", "path constraints are not satisfiable: %r" % (status,)

    def __set_initial_state(self, initial_state):
        """Constrain the pre-trace ("pre") registers and memory."""
        self.__set_state(initial_state, "pre")

    def __set_final_state(self, final_state):
        """Constrain the post-trace ("post") registers and memory."""
        self.__set_state(final_state, "post")

    def __set_state(self, state, mode):
        """Add the register, memory and extra constraints of *state*.

        *mode* ("pre" or "post") selects which side of the trace the register
        and memory expressions refer to. Each memory entry is constrained
        with size 1 (presumably one byte per entry; verify against
        CodeAnalyzer.get_memory_expr). The state's extra constraints are
        added as-is.
        """
        for reg, val in state.get_registers().items():
            smt_expr = self.__code_analyzer.get_register_expr(reg, mode=mode)
            self.__code_analyzer.add_constraint(smt_expr == val)

        for addr, val in state.get_memory().items():
            smt_expr = self.__code_analyzer.get_memory_expr(addr, 1, mode=mode)
            self.__code_analyzer.add_constraint(smt_expr == val)

        for constr in state.get_constraints():
            self.__code_analyzer.add_constraint(constr)

    def __add_trace_to_solver(self, trace):
        """Feed the (reil_instr, branch_taken) pairs of *trace* to the analyzer.

        A JCC whose condition is a register is not translated. Instead it
        becomes a constraint that the condition is non-zero (branch taken)
        or zero (not taken). Every other instruction is added as-is.
        """
        for reil_instr, branch_taken in trace:
            if reil_instr.mnemonic == ReilMnemonic.JCC and isinstance(reil_instr.operands[0], ReilRegisterOperand):
                oprnd = reil_instr.operands[0]
                oprnd_expr = self.__code_analyzer.get_operand_expr(oprnd)

                branch_expr = oprnd_expr != 0x0 if branch_taken else oprnd_expr == 0x0

                # logger.debug("    Branch: {:#010x}:{:02x}  {:s} ({}) - {:s}".format(reil_instr.address >> 8, reil_instr.address & 0xff, reil_instr, branch_taken, branch_expr))

                self.__code_analyzer.add_constraint(branch_expr)
            else:
                self.__code_analyzer.add_instruction(reil_instr)
Example #22
0
class ArmGadgetClassifierTests(unittest.TestCase):
    """GadgetClassifier / GadgetVerifier tests on 32-bit ARM code.

    Each test builds a tiny gadget from raw little-endian ARM encodings that
    ends in ``bx lr`` or ``blx rN``. GadgetFinder locates the candidates, the
    first candidate is classified, and the test checks the resulting gadget
    type, sources, destination and extra modified registers. Finally it asks
    the verifier to confirm the classification.
    """

    def setUp(self):

        self._arch_info = ArmArchitectureInformation(ARCH_ARM_MODE_32)
        self._smt_solver = SmtSolver()
        self._smt_translator = SmtTranslator(self._smt_solver,
                                             self._arch_info.address_size)
        self._ir_emulator = ReilEmulator(self._arch_info.address_size)

        self._ir_emulator.set_arch_registers(self._arch_info.registers_gp_all)
        self._ir_emulator.set_arch_registers_size(
            self._arch_info.registers_size)
        self._ir_emulator.set_reg_access_mapper(self._arch_info.alias_mapper)

        self._smt_translator.set_reg_access_mapper(
            self._arch_info.alias_mapper)
        self._smt_translator.set_arch_registers_size(
            self._arch_info.registers_size)

        self._code_analyzer = CodeAnalyzer(self._smt_solver,
                                           self._smt_translator)

        self._g_classifier = GadgetClassifier(self._ir_emulator,
                                              self._arch_info)
        self._g_verifier = GadgetVerifier(self._code_analyzer, self._arch_info)

    def _find_and_classify_gadgets(self, binary):
        """Find gadgets (up to 4 instructions deep) in *binary*, loaded at
        address 0, and classify the first candidate.

        :returns: ``(candidates, classified)``.
        """
        g_finder = GadgetFinder(
            ArmDisassembler(), binary,
            ArmTranslator(translation_mode=LITE_TRANSLATION), ARCH_ARM,
            ARCH_ARM_MODE_32)

        g_candidates = g_finder.find(0x00000000, len(binary), instrs_depth=4)
        g_classified = self._g_classifier.classify(g_candidates[0])

        # Debug:
        # self._print_candidates(g_candidates)
        # self._print_classified(g_classified)

        return g_candidates, g_classified

    def test_move_register_1(self):
        # testing : dst_reg <- src_reg
        binary = "\x04\x00\xa0\xe1"  # 0x00 : (4)  mov    r0, r4
        binary += "\x31\xff\x2f\xe1"  # 0x04 : (4)  blx    r1

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 2)

        self.assertEqual(g_classified[0].type, GadgetType.MoveRegister)
        self.assertEqual(g_classified[0].sources,
                         [ReilRegisterOperand("r4", 32)])
        self.assertEqual(g_classified[0].destination,
                         [ReilRegisterOperand("r0", 32)])

        # blx writes the return address into lr (r14).
        self.assertEqual(len(g_classified[0].modified_registers), 1)
        self.assertIn(ReilRegisterOperand("r14", 32),
                      g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_move_register_2(self):
        # testing : dst_reg <- src_reg
        binary = "\x00\x00\x84\xe2"  # 0x00 : (4)  add    r0, r4, #0
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.MoveRegister)
        self.assertEqual(g_classified[0].sources,
                         [ReilRegisterOperand("r4", 32)])
        self.assertEqual(g_classified[0].destination,
                         [ReilRegisterOperand("r0", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    # TODO: test_move_register_n: mul r0, r4, #1

    def test_load_constant_1(self):
        # testing : dst_reg <- constant
        binary = "\x0a\x20\xa0\xe3"  # 0x00 : (4)  mov    r2, #10
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEqual(g_classified[0].sources,
                         [ReilImmediateOperand(10, 32)])
        self.assertEqual(g_classified[0].destination,
                         [ReilRegisterOperand("r2", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)
        self.assertNotIn(ReilRegisterOperand("r2", 32),
                         g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_2(self):
        # testing : dst_reg <- constant
        binary = "\x02\x20\x42\xe0"  # 0x00 : (4)  sub    r2, r2, r2
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEqual(g_classified[0].sources,
                         [ReilImmediateOperand(0, 32)])
        self.assertEqual(g_classified[0].destination,
                         [ReilRegisterOperand("r2", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)
        self.assertNotIn(ReilRegisterOperand("r2", 32),
                         g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_3(self):
        # testing : dst_reg <- constant
        binary = "\x02\x20\x22\xe0"  # 0x00 : (4)  eor    r2, r2, r2
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEqual(g_classified[0].sources,
                         [ReilImmediateOperand(0, 32)])
        self.assertEqual(g_classified[0].destination,
                         [ReilRegisterOperand("r2", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)
        self.assertNotIn(ReilRegisterOperand("r2", 32),
                         g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_4(self):
        # testing : dst_reg <- constant
        binary = "\x00\x20\x02\xe2"  # 0x00 : (4)  and    r2, r2, #0
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEqual(g_classified[0].sources,
                         [ReilImmediateOperand(0, 32)])
        self.assertEqual(g_classified[0].destination,
                         [ReilRegisterOperand("r2", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)
        self.assertNotIn(ReilRegisterOperand("r2", 32),
                         g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_5(self):
        # testing : dst_reg <- constant
        binary = "\x00\x20\x02\xe2"  # 0x00 : (4)  and    r2, r2, #0
        binary += "\x21\x20\x82\xe3"  # 0x04 : (4)  orr    r2, r2, #33
        binary += "\x1e\xff\x2f\xe1"  # 0x08 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEqual(g_classified[0].sources,
                         [ReilImmediateOperand(33, 32)])
        self.assertEqual(g_classified[0].destination,
                         [ReilRegisterOperand("r2", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)
        self.assertNotIn(ReilRegisterOperand("r2", 32),
                         g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_arithmetic_add_1(self):
        # testing : dst_reg <- src1_reg + src2_reg
        binary = "\x08\x00\x84\xe0"  # 0x00 : (4)  add    r0, r4, r8
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.Arithmetic)
        self.assertEqual(
            g_classified[0].sources,
            [ReilRegisterOperand("r4", 32),
             ReilRegisterOperand("r8", 32)])
        self.assertEqual(g_classified[0].destination,
                         [ReilRegisterOperand("r0", 32)])
        self.assertEqual(g_classified[0].operation, "+")

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_arithmetic_sub_1(self):
        # testing : dst_reg <- src1_reg - src2_reg
        binary = "\x08\x00\x44\xe0"  # 0x00 : (4)  sub    r0, r4, r8
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.Arithmetic)
        self.assertEqual(
            g_classified[0].sources,
            [ReilRegisterOperand("r4", 32),
             ReilRegisterOperand("r8", 32)])
        self.assertEqual(g_classified[0].destination,
                         [ReilRegisterOperand("r0", 32)])
        self.assertEqual(g_classified[0].operation, "-")

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_memory_1(self):
        # testing : dst_reg <- m[src_reg]
        binary = "\x00\x30\x94\xe5"  # 0x00 : (4)  ldr    r3, [r4]
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadMemory)
        self.assertEqual(
            g_classified[0].sources,
            [ReilRegisterOperand("r4", 32),
             ReilImmediateOperand(0x0, 32)])
        self.assertEqual(g_classified[0].destination,
                         [ReilRegisterOperand("r3", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_memory_2(self):
        # testing : dst_reg <- m[src_reg + offset]
        binary = "\x33\x30\x94\xe5"  # 0x00 : (4)  ldr    r3, [r4, #0x33]
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.LoadMemory)
        self.assertEqual(
            g_classified[0].sources,
            [ReilRegisterOperand("r4", 32),
             ReilImmediateOperand(0x33, 32)])
        self.assertEqual(g_classified[0].destination,
                         [ReilRegisterOperand("r3", 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    # TODO: ARM's ldr rd, [rn, r2] is not a valid classification right now

    def test_store_memory_1(self):
        # testing : m[dst_reg] <- src_reg
        binary = "\x00\x30\x84\xe5"  # 0x00 : (4)  str    r3, [r4]
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.StoreMemory)
        self.assertEqual(g_classified[0].sources,
                         [ReilRegisterOperand("r3", 32)])
        self.assertEqual(
            g_classified[0].destination,
            [ReilRegisterOperand("r4", 32),
             ReilImmediateOperand(0x0, 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_store_memory_2(self):
        # testing : m[dst_reg + offset] <- src_reg
        binary = "\x33\x30\x84\xe5"  # 0x00 : (4)  str    r3, [r4, #0x33]
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 1)

        self.assertEqual(g_classified[0].type, GadgetType.StoreMemory)
        self.assertEqual(g_classified[0].sources,
                         [ReilRegisterOperand("r3", 32)])
        self.assertEqual(
            g_classified[0].destination,
            [ReilRegisterOperand("r4", 32),
             ReilImmediateOperand(0x33, 32)])

        self.assertEqual(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_arithmetic_load_add_1(self):
        # testing : dst_reg <- dst_reg + mem[src_reg]
        binary = "\x00\x30\x94\xe5"  # 0x00 : (4)  ldr    r3, [r4]
        binary += "\x03\x00\x80\xe0"  # 0x04 : (4)  add    r0, r0, r3
        binary += "\x1e\xff\x2f\xe1"  # 0x08 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 2)

        self.assertEqual(g_classified[1].type, GadgetType.ArithmeticLoad)
        self.assertEqual(g_classified[1].sources, [
            ReilRegisterOperand("r0", 32),
            ReilRegisterOperand("r4", 32),
            ReilImmediateOperand(0x0, 32)
        ])
        self.assertEqual(g_classified[1].destination,
                         [ReilRegisterOperand("r0", 32)])
        self.assertEqual(g_classified[1].operation, "+")

        # r3 is clobbered as the temporary holding the loaded value.
        self.assertEqual(len(g_classified[1].modified_registers), 1)

        self.assertNotIn(ReilRegisterOperand("r0", 32),
                         g_classified[1].modified_registers)
        self.assertIn(ReilRegisterOperand("r3", 32),
                      g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def test_arithmetic_load_add_2(self):
        # testing : dst_reg <- dst_reg + mem[src_reg + offset]
        binary = "\x22\x30\x94\xe5"  # 0x00 : (4)  ldr    r3, [r4, #0x22]
        binary += "\x03\x00\x80\xe0"  # 0x04 : (4)  add    r0, r0, r3
        binary += "\x1e\xff\x2f\xe1"  # 0x08 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 2)

        self.assertEqual(g_classified[1].type, GadgetType.ArithmeticLoad)
        self.assertEqual(g_classified[1].sources, [
            ReilRegisterOperand("r0", 32),
            ReilRegisterOperand("r4", 32),
            ReilImmediateOperand(0x22, 32)
        ])
        self.assertEqual(g_classified[1].destination,
                         [ReilRegisterOperand("r0", 32)])
        self.assertEqual(g_classified[1].operation, "+")

        self.assertEqual(len(g_classified[1].modified_registers), 1)

        self.assertNotIn(ReilRegisterOperand("r0", 32),
                         g_classified[1].modified_registers)
        self.assertIn(ReilRegisterOperand("r3", 32),
                      g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def test_arithmetic_store_add_1(self):
        # testing : m[dst_reg] <- m[dst_reg] + src_reg
        binary = "\x00\x30\x94\xe5"  # 0x00 : (4)  ldr    r3, [r4]
        binary += "\x03\x30\x80\xe0"  # 0x04 : (4)  add    r3, r0, r3
        binary += "\x00\x30\x84\xe5"  # 0x08 : (4)  str    r3, [r4]
        binary += "\x1e\xff\x2f\xe1"  # 0x0c : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 2)

        self.assertEqual(g_classified[1].type, GadgetType.ArithmeticStore)
        self.assertEqual(g_classified[1].sources, [
            ReilRegisterOperand("r4", 32),
            ReilImmediateOperand(0x0, 32),
            ReilRegisterOperand("r0", 32)
        ])
        self.assertEqual(
            g_classified[1].destination,
            [ReilRegisterOperand("r4", 32),
             ReilImmediateOperand(0x0, 32)])
        self.assertEqual(g_classified[1].operation, "+")

        self.assertEqual(len(g_classified[1].modified_registers), 1)

        self.assertNotIn(ReilRegisterOperand("r4", 32),
                         g_classified[1].modified_registers)
        self.assertIn(ReilRegisterOperand("r3", 32),
                      g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def test_arithmetic_store_add_2(self):
        # testing : m[dst_reg + offset] <- m[dst_reg + offset] + src_reg
        binary = "\x22\x30\x94\xe5"  # 0x00 : (4)  ldr    r3, [r4, #0x22]
        binary += "\x03\x30\x80\xe0"  # 0x04 : (4)  add    r3, r0, r3
        binary += "\x22\x30\x84\xe5"  # 0x08 : (4)  str    r3, [r4, #0x22]
        binary += "\x1e\xff\x2f\xe1"  # 0x0c : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), 2)

        self.assertEqual(g_classified[1].type, GadgetType.ArithmeticStore)
        self.assertEqual(g_classified[1].sources, [
            ReilRegisterOperand("r4", 32),
            ReilImmediateOperand(0x22, 32),
            ReilRegisterOperand("r0", 32)
        ])
        self.assertEqual(
            g_classified[1].destination,
            [ReilRegisterOperand("r4", 32),
             ReilImmediateOperand(0x22, 32)])
        self.assertEqual(g_classified[1].operation, "+")

        self.assertEqual(len(g_classified[1].modified_registers), 1)

        self.assertNotIn(ReilRegisterOperand("r4", 32),
                         g_classified[1].modified_registers)
        self.assertIn(ReilRegisterOperand("r3", 32),
                      g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def _print_candidates(self, candidates):
        """Debug helper: print every candidate gadget followed by a separator."""
        print("Candidates :")

        for gadget in candidates:
            print(gadget)
            print("-" * 10)

    def _print_classified(self, classified):
        """Debug helper: print every classified gadget and its type."""
        print("Classified :")

        for gadget in classified:
            print(gadget)
            print(gadget.type)
            print("-" * 10)
class CodeAnalyzerTests(unittest.TestCase):
    """Path-satisfiability checks over a small 32-bit x86 binary."""

    def setUp(self):
        self._arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)
        self._operand_size = self._arch_info.operand_size
        self._memory = Memory()
        self._smt_solver = SmtSolver()
        self._smt_translator = SmtTranslator(self._smt_solver,
                                             self._operand_size)
        self._smt_translator.set_arch_alias_mapper(
            self._arch_info.alias_mapper)
        self._smt_translator.set_arch_registers_size(
            self._arch_info.registers_size)
        self._disasm = X86Disassembler()
        self._ir_translator = X86Translator()
        self._bb_builder = CFGRecoverer(
            RecursiveDescent(self._disasm, self._memory, self._ir_translator,
                             self._arch_info))

    def test_check_path_satisfiability(self):
        """Every simple basic-block path from `start` to `end` is satisfiable.

        The test also fails when no path is enumerated, so it cannot pass
        vacuously.
        """
        if VERBOSE:
            print("[+] Test: test_check_path_satisfiability")

        # binary : stack1
        bin_start_address, bin_end_address = 0x08048ec0, 0x8048f02

        binary = "\x55"  # 0x08048ec0 : push   ebp
        binary += "\x89\xe5"  # 0x08048ec1 : mov    ebp,esp
        binary += "\x83\xec\x60"  # 0x08048ec3 : sub    esp,0x60
        binary += "\x8d\x45\xfc"  # 0x08048ec6 : lea    eax,[ebp-0x4]
        binary += "\x89\x44\x24\x08"  # 0x08048ec9 : mov    DWORD PTR [esp+0x8],eax
        binary += "\x8d\x45\xac"  # 0x08048ecd : lea    eax,[ebp-0x54]
        binary += "\x89\x44\x24\x04"  # 0x08048ed0 : mov    DWORD PTR [esp+0x4],eax
        binary += "\xc7\x04\x24\xa8\x5a\x0c\x08"  # 0x08048ed4 : mov    DWORD PTR [esp],0x80c5aa8
        binary += "\xe8\xa0\x0a\x00\x00"  # 0x08048edb : call   8049980 <_IO_printf>
        binary += "\x8d\x45\xac"  # 0x08048ee0 : lea    eax,[ebp-0x54]
        binary += "\x89\x04\x24"  # 0x08048ee3 : mov    DWORD PTR [esp],eax
        binary += "\xe8\xc5\x0a\x00\x00"  # 0x08048ee6 : call   80499b0 <_IO_gets>
        binary += "\x8b\x45\xfc"  # 0x08048eeb : mov    eax,DWORD PTR [ebp-0x4]
        binary += "\x3d\x44\x43\x42\x41"  # 0x08048eee : cmp    eax,0x41424344
        binary += "\x75\x0c"  # 0x08048ef3 : jne    8048f01 <main+0x41>
        binary += "\xc7\x04\x24\xc0\x5a\x0c\x08"  # 0x08048ef5 : mov    DWORD PTR [esp],0x80c5ac0
        binary += "\xe8\x4f\x0c\x00\x00"  # 0x08048efc : call   8049b50 <_IO_puts>
        binary += "\xc9"  # 0x08048f01 : leave
        binary += "\xc3"  # 0x08048f02 : ret

        self._memory.add_vma(bin_start_address, bytearray(binary))

        # From the function entry up to the `leave`, which is reached both
        # by the `jne` target and by its fall-through.
        start = 0x08048ec0
        end = 0x08048f01

        registers = {
            "eax": GenericRegister("eax", 32, 0xffffd0ec),
            "ecx": GenericRegister("ecx", 32, 0x00000001),
            "edx": GenericRegister("edx", 32, 0xffffd0e4),
            "ebx": GenericRegister("ebx", 32, 0x00000000),
            "esp": GenericRegister("esp", 32, 0xffffd05c),
            "ebp": GenericRegister("ebp", 32, 0x08049580),
            "esi": GenericRegister("esi", 32, 0x00000000),
            "edi": GenericRegister("edi", 32, 0x08049620),
            "eip": GenericRegister("eip", 32, 0x08048ec0),
        }

        flags = {
            "af": GenericFlag("af", 0x0),
            "cf": GenericFlag("cf", 0x0),
            "of": GenericFlag("of", 0x0),
            "pf": GenericFlag("pf", 0x1),
            "sf": GenericFlag("sf", 0x0),
            "zf": GenericFlag("zf", 0x1),
        }

        memory = {}

        bb_list, _ = self._bb_builder.build(bin_start_address,
                                            bin_end_address)
        bb_graph = ControlFlowGraph(bb_list)

        code_analyzer = CodeAnalyzer(self._smt_solver, self._smt_translator,
                                     self._arch_info)

        code_analyzer.set_context(GenericContext(registers, flags, memory))

        paths_checked = 0

        for bb_path in bb_graph.all_simple_bb_paths(start, end):
            paths_checked += 1

            if VERBOSE:
                print("[+] Checking path satisfiability :")
                print("      From : %s" % hex(start))
                print("      To : %s" % hex(end))
                print("      Path : %s" % " -> ".join(
                    hex(bb.address) for bb in bb_path))

            is_sat = code_analyzer.check_path_satisfiability(bb_path,
                                                             start,
                                                             verbose=False)

            if VERBOSE:
                print("[+] Satisfiability : %s" % str(is_sat))

            self.assertTrue(is_sat)

            if VERBOSE:
                # is_sat is known to be true here (asserted above).
                print(code_analyzer.get_context())
                print(":" * 80)
                print("")

        # Without this, an empty path enumeration would make the test pass
        # without checking anything.
        self.assertGreater(paths_checked, 0)
Example #24
0
class CodeAnalyzerTests(unittest.TestCase):
    """CodeAnalyzer tests driven by pre/post-condition constraints on x86 code."""

    def setUp(self):
        arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        smt_solver = SmtSolver()

        smt_translator = SmtTranslator(smt_solver, arch_info.address_size)
        smt_translator.set_arch_alias_mapper(arch_info.alias_mapper)
        smt_translator.set_arch_registers_size(arch_info.registers_size)

        self._arch_info = arch_info
        self._smt_solver = smt_solver
        self._smt_translator = smt_translator

        # Front end: x86 assembly text -> native instruction -> REIL.
        self._x86_parser = X86Parser(ARCH_X86_MODE_32)
        self._x86_translator = X86Translator(ARCH_X86_MODE_32)

        # Analyzer under test, sharing the solver and translator above.
        self._code_analyzer = CodeAnalyzer(smt_solver, smt_translator,
                                           arch_info)

    def test_add_reg_reg(self):
        # Look for a pre-state with eax != 42 for which "add eax, ebx"
        # leaves eax == 42.
        analyzer = self._code_analyzer

        native = [self._x86_parser.parse(text) for text in ["add eax, ebx"]]
        for reil_instr in self.__asm_to_reil(native):
            analyzer.add_instruction(reil_instr)

        eax_before = analyzer.get_register_expr("eax", mode="pre")
        eax_after = analyzer.get_register_expr("eax", mode="post")

        for constraint in (eax_before != 42,   # pre-condition
                           eax_after == 42):   # post-condition
            analyzer.add_constraint(constraint)

        # The constraints must be satisfiable and the model must honor them.
        self.assertEqual(analyzer.check(), 'sat')
        self.assertNotEqual(analyzer.get_expr_value(eax_before), 42)
        self.assertEqual(analyzer.get_expr_value(eax_after), 42)

    def test_add_reg_mem(self):
        # Same as above, but the addend is read from memory.
        analyzer = self._code_analyzer

        native = [self._x86_parser.parse(text)
                  for text in ["add eax, [ebx + 0x1000]"]]
        for reil_instr in self.__asm_to_reil(native):
            analyzer.add_instruction(reil_instr)

        eax_before = analyzer.get_register_expr("eax", mode="pre")
        eax_after = analyzer.get_register_expr("eax", mode="post")

        for constraint in (eax_before != 42,   # pre-condition
                           eax_after == 42):   # post-condition
            analyzer.add_constraint(constraint)

        self.assertEqual(analyzer.check(), 'sat')
        self.assertNotEqual(analyzer.get_expr_value(eax_before), 42)
        self.assertEqual(analyzer.get_expr_value(eax_after), 42)

    def test_add_mem_reg(self):
        # The destination is memory: mem[eax + 0x1000] must go from != 42
        # to == 42.
        analyzer = self._code_analyzer

        native = [self._x86_parser.parse(text)
                  for text in ["add [eax + 0x1000], ebx"]]
        for reil_instr in self.__asm_to_reil(native):
            analyzer.add_instruction(reil_instr)

        eax_before = analyzer.get_register_expr("eax", mode="pre")

        mem_before = analyzer.get_memory(mode="pre")
        mem_after = analyzer.get_memory(mode="post")

        address = eax_before + 0x1000

        for constraint in (mem_before[address] != 42,   # pre-condition
                           mem_after[address] == 42):   # post-condition
            analyzer.add_constraint(constraint)

        self.assertEqual(analyzer.check(), 'sat')
        self.assertNotEqual(analyzer.get_expr_value(mem_before[address]), 42)
        self.assertEqual(analyzer.get_expr_value(mem_after[address]), 42)

    def __asm_to_reil(self, instructions):
        """Number the native instructions 0, 1, 2, ..., translate each one to
        REIL and return all REIL instructions as a single flat list."""
        for address, native_instr in enumerate(instructions):
            native_instr.address = address

        flattened = []
        for native_instr in instructions:
            flattened.extend(self._x86_translator.translate(native_instr))

        return flattened
Example #25
0
class ReilSymbolicEmulator(object):
    """Explore the paths of REIL code, pruning branches with an SMT solver.

    Instructions are executed concretely by a ``ReilCpu``. At each conditional
    branch both directions are considered. For every direction that is not
    avoided and whose path constraint is satisfiable, the target address, the
    trace so far and deep copies of the CPU registers and memory are queued.
    Exploration later resumes from each queued state.
    """

    def __init__(self, arch):
        self.__arch = arch

        self.__memory = ReilMemoryEx(self.__arch.address_size)

        self.__tainter = ReilEmulatorTainter(self, arch=self.__arch)

        self.__emulator = ReilEmulator(self.__arch)

        # The CPU is constructed over self.__memory. The process_container
        # methods later replace self.__cpu.memory with the memory snapshot of
        # each state popped from the queue.
        self.__cpu = ReilCpu(self.__memory, arch=self.__arch)

        self.__smt_solver = None
        self.__smt_translator = None

        self.__code_analyzer = None

        self.__initialize_analyzer()

    def find_address(self, container, start=None, end=None, find=None, avoid=None, initial_state=None):
        """Collect the execution traces that reach a given address.

        Args:
            container (ReilContainer): Code to explore.
            start (int): Native start address. Defaults to the address of the
                container's first instruction.
            end (int): Native address at which a path stops being explored.
            find (int): Native address to look for.
            avoid (list): Native addresses. A conditional branch whose target
                is in this list is never enqueued.
            initial_state: Concrete registers/memory and extra constraints for
                the start of execution. None means none of them.

        Returns:
            list: One trace per path that reached ``find``. Each trace is a
            list of ``(ReilInstruction, taken)`` tuples. ``taken`` is
            True/False for conditional branches and None otherwise.
        """
        # Set initial CPU state.
        self.__set_cpu_state(initial_state)

        # Convert input native addresses to reil addresses.
        start = to_reil_address(start) if start else None
        end = to_reil_address(end) if end else None
        find = to_reil_address(find) if find else None
        avoid = [to_reil_address(addr) for addr in avoid] if avoid else []

        # Load instruction pointer.
        ip = start if start else container[0].address

        execution_state = Queue()

        trace_current = []
        trace_final = []

        self.__fa_process_container(container, find, ip, end, avoid, initial_state, execution_state, trace_current,
                                    trace_final)

        # Only returns when all paths have been visited.
        assert execution_state.empty()

        return trace_final

    def find_state(self, container, start=None, end=None, avoid=None, initial_state=None, final_state=None):
        """Collect the execution traces that can end in a given final state.

        After each conditional branch, the current trace is checked together
        with ``initial_state`` and ``final_state``. If the check is
        satisfiable, a copy of the trace is recorded and that path is not
        explored further.

        Args:
            container (ReilContainer): Code to explore.
            start (int): Native start address. Defaults to the address of the
                container's first instruction.
            end (int): Native address at which a path stops being explored.
            avoid (list): Native addresses. A conditional branch whose target
                is in this list is never enqueued.
            initial_state: Concrete registers/memory and extra constraints for
                the start of execution. None means none of them.
            final_state: Registers, memory and constraints required after the
                trace. It must not be None.

        Returns:
            list: Recorded traces, in the same format as ``find_address``.
        """
        self.__set_cpu_state(initial_state)

        # Convert input native addresses to reil addresses.
        start = to_reil_address(start) if start else None
        end = to_reil_address(end) if end else None
        avoid = [to_reil_address(addr) for addr in avoid] if avoid else []

        # Load instruction pointer.
        ip = start if start else container[0].address

        execution_state = Queue()

        trace_current = []
        trace_final = []

        self.__fs_process_container(container, final_state, ip, end, avoid, initial_state, execution_state,
                                    trace_current, trace_final)

        # Only returns when all paths have been visited.
        assert execution_state.empty()

        return trace_final

    # Read/Write methods
    # ======================================================================== #
    def read_operand(self, operand):
        return self.__cpu.read_operand(operand)

    def write_operand(self, operand, value):
        self.__cpu.write_operand(operand, value)

    def read_memory(self, address, size):
        # Go through the CPU, as read_operand does. self.__memory is only the
        # memory the CPU was constructed with; the CPU's memory is replaced
        # whenever a queued state is resumed.
        return self.__cpu.memory.read(address, size)

    def write_memory(self, address, size, value):
        # See read_memory: write to the memory the CPU is currently using.
        self.__cpu.memory.write(address, size, value)

    # Auxiliary methods.
    # ======================================================================== #
    def __process_branch_direct(self, trace_current, target_addr, avoid, instr, execution_state, initial_state, taken):
        """Enqueue one direction of a conditional branch if it is feasible.

        Unless ``target_addr`` is avoided, ``(instr, taken)`` is appended to
        ``trace_current`` IN PLACE, so the caller's list is modified. The
        trace is then checked with the solver. When it is satisfiable,
        ``(target_addr, trace_current, registers copy, memory copy)`` is
        queued; the queued trace is that same list object.
        """
        taken_str = "TAKEN" if taken else "NOT TAKEN"

        if target_addr in avoid:
            logger.debug("[+] Avoiding target address ({:s}) : {:#08x}:{:02x}".format(taken_str, target_addr >> 8, target_addr & 0xff))
        else:
            logger.debug("[+] Checking target satisfiability ({:s}) : {:#08x}:{:02x} -> {:#08x}:{:02x}".format(taken_str, instr.address >> 8, instr.address & 0xff, target_addr >> 8, target_addr & 0xff))

            trace_current += [(instr, taken)]

            self.__reset_solver()
            self.__set_initial_state(initial_state)
            self.__add_trace_to_solver(trace_current)

            is_sat = self.__code_analyzer.check()

            logger.debug("[+] Target satisfiable? : {}".format(is_sat == 'sat'))

            if is_sat == 'sat':
                logger.debug("[+] Enqueueing target address ({:s}) : {:#08x}:{:02x}".format(taken_str, target_addr >> 8, target_addr & 0xff))

                execution_state.put((target_addr, trace_current, copy.deepcopy(self.__cpu.registers), copy.deepcopy(self.__cpu.memory)))

    def __process_branch_cond(self, instr, avoid, initial_state, execution_state, trace_current, not_taken_addr):
        """Fork on a conditional branch. Always returns None: both directions
        are continued only through the execution-state queue."""
        # Direct branch (for example: JCC cond, empty, 0x12345678:00)
        if isinstance(instr.operands[2], ReilImmediateOperand):
            # TAKEN
            # ======================================================== #
            # Snapshot the trace first: the TAKEN call extends the caller's
            # list in place, and NOT TAKEN must start from the pre-branch
            # trace.
            trace_current_bck = list(trace_current)

            target_addr = instr.operands[2].immediate

            self.__process_branch_direct(trace_current, target_addr, avoid, instr, execution_state, initial_state,
                                         True)

            # NOT TAKEN
            # ======================================================== #
            trace_current = trace_current_bck

            target_addr = not_taken_addr

            self.__process_branch_direct(trace_current, target_addr, avoid, instr, execution_state, initial_state,
                                         False)
            # ======================================================== #

            next_ip = None

        # Indirect branch (for example: JCC cond, empty, target)
        else:
            raise Exception("Indirect jump not supported yet.")

        return next_ip

    def __process_branch_uncond(self, instr, trace_current, not_taken_addr):
        """Follow a JCC whose condition is an immediate; returns the next ip."""
        # TAKEN branch (for example: JCC 0x1, empty, oprnd2).
        if instr.operands[0].immediate != 0x0:
            # Direct branch (for example: JCC 0x1, empty, INTEGER)
            if isinstance(instr.operands[2], ReilImmediateOperand):
                trace_current += [(instr, None)]

                next_ip = self.__cpu.execute(instr)

            # Indirect branch (for example: JCC 0x1, empty, REGISTER)
            else:
                raise Exception("Indirect jump not supported yet.")

        # NOT TAKEN branch (for example: JCC 0x0, empty, oprnd2).
        else:
            next_ip = not_taken_addr

        return next_ip

    def __process_instr(self, instr, avoid, next_addr, initial_state, execution_state, trace_current):
        """Process a REIL instruction.

        Args:
            instr (ReilInstruction): Instruction to process.
            avoid (list): List of addresses to avoid while executing the code.
            next_addr (int): Address of the following instruction.
            initial_state (State): Initial execution state.
            execution_state (Queue): Queue of execution states.
            trace_current (list): Current trace. It is extended in place.

        Returns:
            int: The next address to execute, or None after a conditional
            branch (both directions are continued through the queue).
        """
        # Process branch (JCC oprnd0, empty, oprnd2).
        if instr.mnemonic == ReilMnemonic.JCC:
            not_taken_addr = next_addr
            address, index = split_address(instr.address)

            logger.debug("[+] Processing branch: {:#08x}:{:02x} : {}".format(address, index, instr))

            # Process conditional branch (oprnd0 is a REGISTER).
            if isinstance(instr.operands[0], ReilRegisterOperand):
                next_ip = self.__process_branch_cond(instr, avoid, initial_state, execution_state, trace_current, not_taken_addr)

            # Process unconditional branch (oprnd0 is an INTEGER).
            else:
                next_ip = self.__process_branch_uncond(instr, trace_current, not_taken_addr)

        # Process the rest of the instructions.
        else:
            trace_current += [(instr, None)]

            self.__cpu.execute(instr)

            next_ip = next_addr

        return next_ip

    # find_state's auxiliary methods
    # ======================================================================== #
    def __fs_process_container(self, container, final_state, start, end, avoid, initial_state, execution_state,
                               trace_current, trace_final):
        """Main loop of find_state: execute from ``start`` and drain the queue."""
        ip = start

        while ip:
            # Fetch next instruction.
            try:
                instr = container.fetch(ip)
            except ReilContainerInvalidAddressError:
                logger.debug("Exception @ {:#08x}".format(ip))

                raise ReilContainerInvalidAddressError

            # Compute the address of the following instruction to the fetched one.
            try:
                next_addr = container.get_next_address(ip)
            except Exception:
                logger.debug("Exception @ {:#08x}".format(ip))

                # TODO Should this be considered an error?

                raise ReilContainerInvalidAddressError

            # Process instruction
            next_ip = self.__process_instr(instr, avoid, next_addr, initial_state, execution_state, trace_current)

            # Check is final state holds.
            if instr.mnemonic == ReilMnemonic.JCC and isinstance(instr.operands[0], ReilRegisterOperand):
                # NOTE(review): unless the TAKEN target was avoided,
                # trace_current already ends with (instr, True) here, because
                # __process_branch_direct extends this list in place.
                # TODO Check only when it is necessary.
                self.__reset_solver()
                self.__set_initial_state(initial_state)
                self.__add_trace_to_solver(trace_current)
                self.__set_final_state(final_state)

                is_sat = self.__code_analyzer.check()

                if is_sat == "sat":
                    logger.debug("[+] Final state found!")

                    trace_final.append(list(trace_current))

                    next_ip = None

            # Check termination conditions.
            if end and next_ip and next_ip == end:
                logger.debug("[+] End address found!")

                next_ip = None

            # Update instruction pointer.
            ip = next_ip if next_ip else None

            # Current path finished: resume the next queued state, skipping
            # states that start at the end address.
            while not ip:
                if not execution_state.empty():
                    # Pop next execution state.
                    ip, trace_current, registers, memory = execution_state.get()

                    if split_address(ip)[1] == 0x0:
                        logger.debug("[+] Popping execution state @ {:#x} (INTER)".format(ip))
                    else:
                        logger.debug("[+] Popping execution state @ {:#x} (INTRA)".format(ip))

                    # Setup cpu and memory.
                    self.__cpu.registers = registers
                    self.__cpu.memory = memory

                    logger.debug("[+] Next address: {:#08x}:{:02x}".format(ip >> 8, ip & 0xff))
                else:
                    logger.debug("[+] No more paths to explore! Exiting...")
                    break

                # Check termination conditions (AGAIN...).
                if end and ip == end:
                    logger.debug("[+] End address found!")

                    ip = None

    # find_address's auxiliary methods
    # ======================================================================== #
    def __fa_process_sequence(self, sequence, avoid, initial_state, execution_state, trace_current, next_addr):
        """Process a REIL sequence.

        NOTE(review): no live code path in this class calls this method; it
        is referenced only by the commented-out experiment in
        __fa_process_container.

        Args:
            sequence (ReilSequence): A REIL sequence to process.
            avoid (list): List of address to avoid.
            initial_state: Initial state.
            execution_state: Execution state queue.
            trace_current (list): Current trace.
            next_addr: Address of the next instruction following the current one.

        Returns:
            Returns the next instruction to execute in case there is one, otherwise returns None.
        """
        # TODO: Process execution intra states.

        ip = sequence.address
        next_ip = None

        while ip:
            # Fetch next instruction in the sequence.
            try:
                instr = sequence.fetch(ip)
            except ReilSequenceInvalidAddressError:
                # At this point, ip should be a native instruction address, therefore
                # the index should be zero.
                assert split_address(ip)[1] == 0x0
                next_ip = ip
                break

            try:
                target_addr = sequence.get_next_address(ip)
            except ReilSequenceInvalidAddressError:
                # We reached the end of the sequence. Execution continues on the next native instruction
                # (it's a REIL address).
                target_addr = next_addr

            next_ip = self.__process_instr(instr, avoid, target_addr, initial_state, execution_state, trace_current)

            # Update instruction pointer.
            try:
                ip = next_ip if next_ip else sequence.get_next_address(ip)
            except ReilSequenceInvalidAddressError:
                break

        return next_ip

    def __fa_process_container(self, container, find, start, end, avoid, initial_state, execution_state, trace_current,
                               trace_final):
        """Process a REIL container.

        Args:
            avoid (list): List of addresses to avoid while executing the code.
            container (ReilContainer): REIL container to execute.
            end (int): End address.
            execution_state (Queue): Queue of execution states.
            find (int): Address to find.
            initial_state (State): Initial state.
            start (int): Start address.
            trace_current (list): Trace of the path being executed.
            trace_final (list): Output list; a copy of the trace is appended
                each time ``find`` is reached.
        """
        ip = start

        while ip:
            # NOTE *ip* and *next_addr* variables can be, independently, either intra
            # or inter addresses.

            # Fetch next instruction.
            try:
                instr = container.fetch(ip)
            except ReilContainerInvalidAddressError:
                logger.debug("Exception @ {:#08x}".format(ip))

                raise ReilContainerInvalidAddressError

            # Compute the address of the following instruction to the fetched one.
            try:
                next_addr = container.get_next_address(ip)
            except Exception:
                logger.debug("Exception @ {:#08x}".format(ip))

                # TODO Should this be considered an error?

                raise ReilContainerInvalidAddressError

            # Process the instruction.
            next_ip = self.__process_instr(instr, avoid, next_addr, initial_state, execution_state, trace_current)

            # # ====================================================================================================== #
            # # NOTE This is an attempt to separate intra and inter instruction
            # # addresses processing. Here, *ip* and *next_addr* are always inter
            # # instruction addresses.
            #
            # assert split_address(ip)[1] == 0x0
            #
            # # Compute the address of the following instruction to the fetched one.
            # try:
            #     seq = container.fetch_sequence(ip)
            # except ReilContainerInvalidAddressError:
            #     logger.debug("Exception @ {:#08x}".format(ip))
            #
            #     raise ReilContainerInvalidAddressError
            #
            # # Fetch next instruction address.
            # try:
            #     next_addr = container.get_next_address(ip + len(seq))
            # except Exception:
            #     logger.debug("Exception @ {:#08x}".format(ip))
            #
            #     # TODO Should this be considered an error?
            #
            #     raise ReilContainerInvalidAddressError
            #
            # next_ip = self.__process_sequence(seq, avoid, initial_state, execution_state, trace_current, next_addr)
            #
            # if next_ip:
            #     assert split_address(next_ip)[1] == 0x0
            #
            # # ====================================================================================================== #

            # Check termination conditions.
            if find and next_ip and next_ip == find:
                logger.debug("[+] Find address found!")

                trace_final.append(list(trace_current))

                next_ip = None

            if end and next_ip and next_ip == end:
                logger.debug("[+] End address found!")

                next_ip = None

            # Update instruction pointer.
            ip = next_ip if next_ip else None

            # Current path finished: resume the next queued state. States that
            # start at ``find`` are recorded; states that start at ``find`` or
            # ``end`` are not executed.
            while not ip:
                if not execution_state.empty():
                    # Pop next execution state.
                    ip, trace_current, registers, memory = execution_state.get()

                    if split_address(ip)[1] == 0x0:
                        logger.debug("[+] Popping execution state @ {:#x} (INTER)".format(ip))
                    else:
                        logger.debug("[+] Popping execution state @ {:#x} (INTRA)".format(ip))

                    # Setup cpu and memory.
                    self.__cpu.registers = registers
                    self.__cpu.memory = memory

                    logger.debug("[+] Next address: {:#08x}:{:02x}".format(ip >> 8, ip & 0xff))
                else:
                    logger.debug("[+] No more paths to explore! Exiting...")

                    break

                # Check termination conditions (AGAIN...).
                if find and ip == find:
                    logger.debug("[+] Find address found!")

                    trace_final.append(list(trace_current))

                    ip = None

                if end and ip == end:
                    logger.debug("[+] End address found!")

                    ip = None

    # Auxiliary methods
    # ======================================================================== #
    def __initialize_analyzer(self):
        self.__smt_solver = Z3Solver()

        self.__smt_translator = SmtTranslator(self.__smt_solver, self.__arch.address_size)
        self.__smt_translator.set_arch_alias_mapper(self.__arch.alias_mapper)
        self.__smt_translator.set_arch_registers_size(self.__arch.registers_size)

        self.__code_analyzer = CodeAnalyzer(self.__smt_solver, self.__smt_translator, self.__arch)

    def __reset_solver(self):
        self.__code_analyzer.reset()

    def __set_cpu_state(self, state):
        """Load the concrete registers and memory of ``state`` into the CPU.

        ``state`` may be None (the default of find_address/find_state). In
        that case the CPU is left untouched.
        """
        if state is None:
            return

        # Set registers
        for reg, val in state.get_registers().items():
            self.__cpu.registers[reg] = val

        # Set memory (size argument 1; its unit is whatever
        # ReilCpu.write_memory uses -- TODO confirm)
        for addr, val in state.get_memory().items():
            self.__cpu.write_memory(addr, 1, val)

    def __set_initial_state(self, initial_state):
        """Constrain the solver's "pre" registers/memory to ``initial_state``.

        Extra constraints carried by the state are also added. None adds
        nothing.
        """
        if initial_state is None:
            return

        # Set registers
        for reg, val in initial_state.get_registers().items():
            smt_expr = self.__code_analyzer.get_register_expr(reg, mode="pre")
            self.__code_analyzer.add_constraint(smt_expr == val)

        # Set memory
        for addr, val in initial_state.get_memory().items():
            smt_expr = self.__code_analyzer.get_memory_expr(addr, 1, mode="pre")
            self.__code_analyzer.add_constraint(smt_expr == val)

        # Set constraints
        for constr in initial_state.get_constraints():
            self.__code_analyzer.add_constraint(constr)

    def __set_final_state(self, final_state):
        """Constrain the solver's "post" registers/memory to ``final_state``.

        Extra constraints carried by the state are also added.
        """
        # Set registers
        for reg, val in final_state.get_registers().items():
            smt_expr = self.__code_analyzer.get_register_expr(reg, mode="post")
            self.__code_analyzer.add_constraint(smt_expr == val)

        # Set memory
        for addr, val in final_state.get_memory().items():
            smt_expr = self.__code_analyzer.get_memory_expr(addr, 1, mode="post")
            self.__code_analyzer.add_constraint(smt_expr == val)

        # Set constraints
        for constr in final_state.get_constraints():
            self.__code_analyzer.add_constraint(constr)

    def __add_trace_to_solver(self, trace):
        """Feed a trace to the analyzer.

        A conditional JCC becomes a constraint on its condition operand
        (!= 0 when taken, == 0 otherwise). Every other instruction is added
        as an instruction.
        """
        for reil_instr, branch_taken in trace:
            if reil_instr.mnemonic == ReilMnemonic.JCC and isinstance(reil_instr.operands[0], ReilRegisterOperand):
                oprnd = reil_instr.operands[0]
                oprnd_expr = self.__code_analyzer.get_operand_expr(oprnd)

                branch_expr = oprnd_expr != 0x0 if branch_taken else oprnd_expr == 0x0

                # logger.debug("    Branch: {:#010x}:{:02x}  {:s} ({}) - {:s}".format(reil_instr.address >> 8, reil_instr.address & 0xff, reil_instr, branch_taken, branch_expr))

                self.__code_analyzer.add_constraint(branch_expr)
            else:
                self.__code_analyzer.add_instruction(reil_instr)
Example #26
0
class SmtTranslatorTests(unittest.TestCase):
    def setUp(self):
        # unittest runs this before every test. Each test therefore gets its
        # own parser, solver and 32-bit translator instances.
        address_size = 32
        parser = ReilParser()
        solver = SmtSolver()
        translator = SmtTranslator(solver, address_size)

        self._address_size = address_size
        self._parser = parser
        self._solver = solver
        self._translator = translator

    def test_add_reg_reg(self):
        """``add eax, ebx``: current eax == 42 with initial eax != 42 is sat."""
        if VERBOSE:
            print("\n[+] Test: test_add_reg_reg")

        # add eax, ebx
        instrs = self._parser.parse([
            "add [eax, ebx, t0]",
            "str [t0, e, eax]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[1].size = 32
        instrs[0].operands[2].size = 32

        instrs[1].operands[0].size = 32
        instrs[1].operands[2].size = 32

        self._solver.reset()

        # Translate instructions to formulae. This must happen even when
        # VERBOSE is off; otherwise the solver checks an empty formula and
        # the assertion below is vacuous. Only the printing is optional.
        if VERBOSE:
            print("[+] Instructions:")

        for instr in instrs:
            trans = self._translator.translate(instr)

            if trans is not None:
                self._solver.add(trans)

            if VERBOSE:
                print("    %-20s -> %-20s" % (instr, trans))

        # Constrain the current eax (get_curr_name) to 42 and the initial
        # eax (get_init_name) to something other than 42.
        constraints = [
            self._solver.mkBitVec(
                32, self._translator.get_curr_name("eax")) == 42,
            self._solver.mkBitVec(
                32, self._translator.get_init_name("eax")) != 42,
        ]

        if VERBOSE:
            print("[+] Constraints :")

        for i, constr in enumerate(constraints):
            self._solver.add(constr)

            if VERBOSE:
                print("    %2d : %s" % (i, constr))

        # check satisfiability
        is_sat = self._solver.check() == 'sat'

        if VERBOSE:
            print("[+] Satisfiability : %s" % str(is_sat))

            if is_sat:
                print("    eax : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("eax")))
                print("    ebx : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("ebx")))

            print("~" * 80)

        self.assertTrue(is_sat)

    def test_add_reg_mem(self):
        """``add eax, [ebx]``: current eax == 42 with initial eax != 42 is sat."""
        if VERBOSE:
            print("\n[+] Test: test_add_reg_mem")

        # add eax, [ebx]
        instrs = self._parser.parse([
            "ldm [ebx, EMPTY, t0]",
            "add [eax, t0, t1]",
            "str [t1, EMPTY, eax]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[2].size = 32

        instrs[1].operands[0].size = 32
        instrs[1].operands[1].size = 32
        instrs[1].operands[2].size = 32

        instrs[2].operands[0].size = 32
        instrs[2].operands[2].size = 32

        self._solver.reset()

        # Translate instructions to formulae. This must happen even when
        # VERBOSE is off; otherwise the solver checks an empty formula and
        # the assertion below is vacuous. Only the printing is optional.
        if VERBOSE:
            print("[+] Instructions:")

        for instr in instrs:
            trans = self._translator.translate(instr)

            if trans is not None:
                self._solver.add(trans)

            if VERBOSE:
                print("    %-20s -> %-20s" % (instr, trans))

        # Constrain the current eax (get_curr_name) to 42 and the initial
        # eax (get_init_name) to something other than 42.
        constraints = [
            self._solver.mkBitVec(
                32, self._translator.get_curr_name("eax")) == 42,
            self._solver.mkBitVec(
                32, self._translator.get_init_name("eax")) != 42,
        ]

        if VERBOSE:
            print("[+] Constraints :")

        for i, constr in enumerate(constraints):
            self._solver.add(constr)

            if VERBOSE:
                print("    %2d : %s" % (i, constr))

        # check satisfiability
        is_sat = self._solver.check() == 'sat'

        if VERBOSE:
            print("[+] Satisfiability : %s" % str(is_sat))

            if is_sat:
                print("    eax : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("eax")))
                print("    ebx : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("ebx")))

            print("~" * 80)

        self.assertTrue(is_sat)

    def test_add_mem_reg(self):
        """``add [eax], ebx``: mem[eax] != 42 before and == 42 after is sat."""
        if VERBOSE:
            print("\n[+] Test: test_add_mem_reg")

        # add [eax], ebx
        instrs = self._parser.parse([
            "ldm [eax, EMPTY, t0]",
            "add [t0, ebx, t1]",
            "stm [t1, EMPTY, eax]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[2].size = 32

        instrs[1].operands[0].size = 32
        instrs[1].operands[1].size = 32
        instrs[1].operands[2].size = 32

        instrs[2].operands[0].size = 32
        instrs[2].operands[2].size = 32

        self._solver.reset()

        # Pre-execution constraint on the target memory location.
        mem = self._translator.get_memory()
        eax = self._solver.mkBitVec(32, "eax_0")

        constraint = (mem[eax] != 42)

        if VERBOSE:
            print("constraint : %s" % constraint)

        self._solver.add(constraint)

        # Translate instructions to formulae. This must happen even when
        # VERBOSE is off; otherwise the solver checks an almost empty formula
        # and the assertion below is vacuous. Only the printing is optional.
        if VERBOSE:
            print("[+] Instructions:")

        for instr in instrs:
            trans = self._translator.translate(instr)

            if trans is not None:
                self._solver.add(trans)

            if VERBOSE:
                print("    %-20s -> %-20s" % (instr, trans))

        # NOTE(review): `mem` was fetched before translation. `mem[eax] == 42`
        # only describes the post-stm memory if the translator updates this
        # same object in place; otherwise it contradicts the constraint
        # above. Confirm against SmtTranslator.get_memory.
        constraints = [
            mem[eax] == 42,
            self._solver.mkBitVec(
                32, self._translator.get_init_name("t0")) != 42,
        ]

        if VERBOSE:
            print("[+] Constraints :")

        for i, constr in enumerate(constraints):
            self._solver.add(constr)

            if VERBOSE:
                print("    %2d : %s" % (i, constr))

        # check satisfiability
        is_sat = self._solver.check() == 'sat'

        if VERBOSE:
            print("[+] Satisfiability : %s" % str(is_sat))

            if is_sat:
                print("    eax : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("eax")))
                print("    ebx : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("ebx")))
                print("    t0 : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("t0")))
                print("    [eax] : 0x%08x" % self._solver.getvalue(mem[eax]))

            print("~" * 80)

        self.assertTrue(is_sat)

    def test_add_mem_reg_2(self):
        """``add [eax + 0x1000], ebx``: location != 42 before, == 42 after."""
        if VERBOSE:
            print("\n[+] Test: test_add_mem_reg_2")

        # add [eax + 0x1000], ebx
        instrs = self._parser.parse([
            "add [eax, 0x1000, t0]",
            "ldm [t0, e, t1]",
            "add [t1, ebx, t2]",
            "stm [t2, e, t0]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[1].size = 32
        instrs[0].operands[2].size = 32

        instrs[1].operands[0].size = 32
        instrs[1].operands[2].size = 32

        instrs[2].operands[0].size = 32
        instrs[2].operands[1].size = 32
        instrs[2].operands[2].size = 32

        instrs[3].operands[0].size = 32
        instrs[3].operands[2].size = 32

        self._solver.reset()

        # Pre-execution constraint on the target memory location.
        mem = self._translator.get_memory()
        eax = self._solver.mkBitVec(32, "eax_0")
        off = BitVec(32, "#x%08x" % 0x1000)

        constraint = (mem[eax + off] != 42)

        if VERBOSE:
            print("constraint : %s" % constraint)

        self._solver.add(constraint)

        # Translate instructions to formulae. This must happen even when
        # VERBOSE is off; otherwise the solver checks an almost empty formula
        # and the assertion below is vacuous. Only the printing is optional.
        if VERBOSE:
            print("[+] Instructions:")

        for instr in instrs:
            trans = self._translator.translate(instr)

            if trans is not None:
                self._solver.add(trans)

            if VERBOSE:
                print("    %-20s -> %-20s" % (instr, trans))

        # NOTE(review): `mem` was fetched before translation. The post-stm
        # check only differs from the pre-check if the translator updates
        # this same object in place. Confirm against
        # SmtTranslator.get_memory.
        constraints = [
            mem[eax + off] == 42,
            self._solver.mkBitVec(
                32, self._translator.get_init_name("t1")) != 42,
        ]

        if VERBOSE:
            print("[+] Constraints :")

        for i, constr in enumerate(constraints):
            self._solver.add(constr)

            if VERBOSE:
                print("    %2d : %s" % (i, constr))

        # check satisfiability
        is_sat = self._solver.check() == 'sat'

        if VERBOSE:
            print("[+] Satisfiability : %s" % str(is_sat))

            if is_sat:
                print("    eax : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("eax")))
                print("    ebx : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("ebx")))
                print("    t0 : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("t0")))
                print("    t1 : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("t1")))
                print("    [eax + off] : 0x%08x" % self._solver.getvalue(
                    mem[eax + off]))

            print("~" * 80)

        self.assertTrue(is_sat)

    def test_mul(self):
        """``mul 0x0, 0x1 -> t0``: current t0 == 0 with initial t0 != 0 is sat."""
        if VERBOSE:
            print("\n[+] Test: test_mul")

        instrs = self._parser.parse([
            "mul [0x0, 0x1, t0]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[1].size = 32

        # TODO: Review this; the output size should be 64.
        instrs[0].operands[2].size = 32

        self._solver.reset()

        # Translate instructions to formulae. This must happen even when
        # VERBOSE is off; otherwise the solver checks an empty formula and
        # the assertion below is vacuous. Only the printing is optional.
        if VERBOSE:
            print("[+] Instructions:")

        for instr in instrs:
            trans = self._translator.translate(instr)

            if trans is not None:
                self._solver.add(trans)

            if VERBOSE:
                print("    %-20s -> %-20s" % (instr, trans))

        # Constrain the current t0 (get_curr_name) to 0 and the initial t0
        # (get_init_name) to something non-zero.
        constraints = [
            self._solver.mkBitVec(
                32, self._translator.get_curr_name("t0")) == 0,
            self._solver.mkBitVec(
                32, self._translator.get_init_name("t0")) != 0,
        ]

        if VERBOSE:
            print("[+] Constraints :")

        for i, constr in enumerate(constraints):
            self._solver.add(constr)

            if VERBOSE:
                print("    %2d : %s" % (i, constr))

        # check satisfiability
        is_sat = self._solver.check() == 'sat'

        if VERBOSE:
            print("[+] Satisfiability : %s" % str(is_sat))

            if is_sat:
                print("    t0 : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("t0")))

            print("~" * 80)

        self.assertTrue(is_sat)