Exemple #1
0
    def setUp(self):
        """Create the 32-bit x86 fixtures used by every test in the case."""
        arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        self._arch_info = arch_info
        # The emulator is built directly from the architecture description.
        self._emulator = ReilEmulator(arch_info)

        # Parser and translator are created with their default arguments.
        self._asm_parser = X86Parser()
        self._translator = X86Translator()
    def setUp(self):
        """Create ARM fixtures (ARCH_ARM_MODE_THUMB, full translation)."""
        trans_mode = self.trans_mode = FULL_TRANSLATION
        arch_mode = self.arch_mode = ARCH_ARM_MODE_THUMB
        arch_info = self.arch_info = ArmArchitectureInformation(arch_mode)

        # Parser and translator share the same architecture mode.
        self.arm_parser = ArmParser(arch_mode)
        self.arm_translator = ArmTranslator(arch_mode, trans_mode)
        self.reil_emulator = ReilEmulator(arch_info)

        # Only stored here; presumably consumed by the tests when a context
        # has to be saved or reloaded -- confirm against the test bodies.
        self.context_filename = "failing_context.data"
Exemple #3
0
    def setUp(self):
        """Create 32-bit x86 fixtures, configuring the emulator by hand."""
        arch_info = self._arch_info = X86ArchitectureInformation(
            ARCH_X86_MODE_32)

        # This emulator only receives the address size up front ...
        emulator = self._emulator = ReilEmulator(arch_info.address_size)

        # ... so the register set, sizes and alias mapping are supplied
        # through its setters afterwards.
        emulator.set_arch_registers(arch_info.registers_gp_all)
        emulator.set_arch_registers_size(arch_info.registers_size)
        emulator.set_reg_access_mapper(arch_info.alias_mapper)

        self._asm_parser = X86Parser()
        self._translator = X86Translator()
Exemple #4
0
    def __init__(self, binary):
        """Set up the analysis components for the ELF file *binary*.

        *binary* is handed unchanged to ``elffile.ELFFile``.  The x86
        architecture mode is selected from the ELF class (32 or 64).

        Raises:
            ValueError: if the ELF class is neither 32 nor 64.  Previously
                this case left ``arch_info`` unset and failed a few lines
                later with an unrelated AttributeError.
        """
        self.elf = elffile.ELFFile(binary)

        # ELF class (word size) -> x86 architecture mode.
        arch_modes = {
            32: ARCH_X86_MODE_32,
            64: ARCH_X86_MODE_64,
        }

        if self.elf.elfclass not in arch_modes:
            raise ValueError(
                "unsupported ELF class: %r" % (self.elf.elfclass,))

        self.arch_info = X86ArchitectureInformation(
            arch_modes[self.elf.elfclass])

        # REIL emulator, configured with the architecture's registers.
        self.emulator = ReilEmulator(self.arch_info.address_size)
        self.emulator.set_arch_registers(self.arch_info.registers_gp)
        self.emulator.set_arch_registers_size(self.arch_info.register_size)
        self.emulator.set_reg_access_mapper(
            self.arch_info.register_access_mapper())

        self.classifier = GadgetClassifier(self.emulator, self.arch_info)

        # SMT back end; it gets its own register access mapper instance.
        self.smt_solver = SmtSolver()
        self.smt_translator = SmtTranslator(self.smt_solver,
                                            self.arch_info.address_size)

        self.smt_translator.set_reg_access_mapper(
            self.arch_info.register_access_mapper())
        self.smt_translator.set_arch_registers_size(
            self.arch_info.register_size)

        self.code_analyzer = CodeAnalyzer(self.smt_solver,
                                          self.smt_translator)

        self.gadgets = {}
        self.classified_gadgets = {}

        # Helper objects receive this instance; they are created only after
        # the attributes above exist, exactly as in the original order.
        self.regset = RegSet(self)
        self.ccf = CCFlag(self)
        self.ams = ArithmeticStore(self)
        self.memstr = MemoryStore(self)

        self.reil_translator = X86Translator(
            architecture_mode=self.arch_info.architecture_mode,
            translation_mode=FULL_TRANSLATION)
    def setUp(self):
        """Build the x86 (ARCH_X86_MODE_32) emulator, parser and translator."""
        info = X86ArchitectureInformation(ARCH_X86_MODE_32)
        self._arch_info = info

        # The emulator takes the whole architecture description.
        self._emulator = ReilEmulator(info)

        self._asm_parser, self._translator = X86Parser(), X86Translator()
Exemple #6
0
    def setUp(self):
        """Create ARM (ARCH_ARM_MODE_32) fixtures with a hand-configured emulator."""
        trans_mode = self.trans_mode = FULL_TRANSLATION
        arch_mode = self.arch_mode = ARCH_ARM_MODE_32
        arch_info = self.arch_info = ArmArchitectureInformation(arch_mode)

        self.arm_parser = ArmParser(arch_mode)
        self.arm_translator = ArmTranslator(arch_mode, trans_mode)

        # The emulator is created from the address size alone; registers,
        # their sizes and the alias mapping are supplied via setters.
        emulator = self.reil_emulator = ReilEmulator(arch_info.address_size)
        emulator.set_arch_registers(arch_info.registers_gp_all)
        emulator.set_arch_registers_size(arch_info.registers_size)
        emulator.set_reg_access_mapper(arch_info.alias_mapper)

        self.context_filename = "failing_context.data"
    def setUp(self):
        """Create ARM fixtures (ARCH_ARM_MODE_32, full translation)."""
        trans_mode = self.trans_mode = FULL_TRANSLATION
        arch_mode = self.arch_mode = ARCH_ARM_MODE_32
        arch_info = self.arch_info = ArmArchitectureInformation(arch_mode)

        # Parser and translator are built for the same architecture mode.
        self.arm_parser = ArmParser(arch_mode)
        self.arm_translator = ArmTranslator(arch_mode, trans_mode)

        # Here the emulator receives the full architecture description.
        self.reil_emulator = ReilEmulator(arch_info)

        self.context_filename = "failing_context.data"
    def setUp(self):
        """Build x86 (ARCH_X86_MODE_32) fixtures and configure the emulator."""
        info = self._arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        emu = self._emulator = ReilEmulator(info.address_size)

        # Tell the emulator about the registers, their sizes and aliases.
        emu.set_arch_registers(info.registers_gp_all)
        emu.set_arch_registers_size(info.registers_size)
        emu.set_reg_access_mapper(info.alias_mapper)

        self._asm_parser = X86Parser()
        self._translator = X86Translator()
Exemple #9
0
    def test_emulate_arm(self):
        # Smoke test: load the ARM sample binary and emulate a fixed address
        # range; the test passes when no exception is raised.
        sample = BinaryFile(get_full_path("./samples/bin/loop-simple.arm"))

        arch_mode = ARCH_ARM_MODE_ARM
        arch_info = ArmArchitectureInformation(arch_mode)

        reil_emulator = ReilEmulator(arch_info)
        arm_disassembler = ArmDisassembler(architecture_mode=arch_mode)
        arm_translator = ArmTranslator(architecture_mode=arch_mode)

        emu = Emulator(arch_info, reil_emulator, arm_translator,
                       arm_disassembler)
        emu.load_binary(sample)

        # Start and end addresses of the emulated range.
        start, end = 0x10400, 0x10460

        emu.emulate(start, end, {}, None, True)
Exemple #10
0
    def test_emulate_x86_64(self):
        # Smoke test: load the x86-64 sample binary and emulate a fixed
        # address range; the test passes when no exception is raised.
        sample = BinaryFile(get_full_path("./samples/bin/loop-simple.x86_64"))

        arch_mode = ARCH_X86_MODE_64
        arch_info = X86ArchitectureInformation(arch_mode)

        reil_emulator = ReilEmulator(arch_info)
        x86_disassembler = X86Disassembler(architecture_mode=arch_mode)
        x86_translator = X86Translator(architecture_mode=arch_mode)

        emu = Emulator(arch_info, reil_emulator, x86_translator,
                       x86_disassembler)
        emu.load_binary(sample)

        # Start and end addresses of the emulated range.
        start, end = 0x4004d6, 0x400507

        emu.emulate(start, end, {}, None, False)
Exemple #11
0
    def setUp(self):
        """Create the ARM gadget classifier and verifier with their back ends."""
        arch_info = self._arch_info = ArmArchitectureInformation(
            ARCH_ARM_MODE_32)

        solver = self._smt_solver = SmtSolver()
        smt_translator = self._smt_translator = SmtTranslator(
            solver, arch_info.address_size)
        emulator = self._ir_emulator = ReilEmulator(arch_info.address_size)

        # Emulator: register set, register sizes and alias mapping.
        emulator.set_arch_registers(arch_info.registers_gp_all)
        emulator.set_arch_registers_size(arch_info.registers_size)
        emulator.set_reg_access_mapper(arch_info.alias_mapper)

        # SMT translator: same alias mapping and register sizes.
        smt_translator.set_reg_access_mapper(arch_info.alias_mapper)
        smt_translator.set_arch_registers_size(arch_info.registers_size)

        analyzer = self._code_analyzer = CodeAnalyzer(solver, smt_translator)

        # The classifier works on the emulator, the verifier on the analyzer.
        self._g_classifier = GadgetClassifier(emulator, arch_info)
        self._g_verifier = GadgetVerifier(analyzer, arch_info)
    def setUp(self):
        """Build ARM (ARCH_ARM_MODE_32) parser, translator and emulator."""
        self.trans_mode = FULL_TRANSLATION
        self.arch_mode = mode = ARCH_ARM_MODE_32
        self.arch_info = info = ArmArchitectureInformation(mode)

        self.arm_parser = ArmParser(mode)
        self.arm_translator = ArmTranslator(mode, self.trans_mode)

        # Emulator built from the address size, then configured explicitly.
        self.reil_emulator = emu = ReilEmulator(info.address_size)
        emu.set_arch_registers(info.registers_gp_all)
        emu.set_arch_registers_size(info.registers_size)
        emu.set_reg_access_mapper(info.alias_mapper)

        self.context_filename = "failing_context.data"
    def setUp(self):
        """Wire up the ARM gadget classifier/verifier and what they depend on."""
        info = ArmArchitectureInformation(ARCH_ARM_MODE_32)
        self._arch_info = info

        self._smt_solver = SmtSolver()
        self._smt_translator = SmtTranslator(self._smt_solver,
                                             info.address_size)
        self._ir_emulator = ReilEmulator(info.address_size)

        # Register description for the emulator.
        ir_emulator = self._ir_emulator
        ir_emulator.set_arch_registers(info.registers_gp_all)
        ir_emulator.set_arch_registers_size(info.registers_size)
        ir_emulator.set_reg_access_mapper(info.alias_mapper)

        # Register description for the SMT translator.
        smt_translator = self._smt_translator
        smt_translator.set_reg_access_mapper(info.alias_mapper)
        smt_translator.set_arch_registers_size(info.registers_size)

        self._code_analyzer = CodeAnalyzer(self._smt_solver, smt_translator)

        self._g_classifier = GadgetClassifier(ir_emulator, info)
        self._g_verifier = GadgetVerifier(self._code_analyzer, info)
Exemple #14
0
    def setUp(self):
        """Create x86-64 parser, translator, SMT translator and emulator."""
        trans_mode = self.trans_mode = FULL_TRANSLATION
        arch_mode = self.arch_mode = ARCH_X86_MODE_64
        arch_info = self.arch_info = X86ArchitectureInformation(arch_mode)

        self.x86_parser = X86Parser(arch_mode)
        self.x86_translator = X86Translator(arch_mode, trans_mode)

        solver = self.smt_solver = SmtSolver()
        smt_translator = self.smt_translator = SmtTranslator(
            solver, arch_info.address_size)
        emulator = self.reil_emulator = ReilEmulator(arch_info.address_size)

        # Emulator configuration; register_access_mapper() is called once
        # per consumer, as in the original.
        emulator.set_arch_registers(arch_info.registers_gp)
        emulator.set_arch_registers_size(arch_info.register_size)
        emulator.set_reg_access_mapper(arch_info.register_access_mapper())

        # SMT translator configuration.
        smt_translator.set_reg_access_mapper(
            arch_info.register_access_mapper())
        smt_translator.set_arch_registers_size(arch_info.register_size)
Exemple #15
0
class ReilEmulatorTests(unittest.TestCase):
    def setUp(self):
        self._arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        self._emulator = ReilEmulator(self._arch_info)

        self._asm_parser = X86Parser()
        self._reil_parser = ReilParser()

        self._translator = X86Translator()

    def test_add(self):
        asm_instrs = self._asm_parser.parse("add eax, ebx")

        self.__set_address(0xdeadbeef, [asm_instrs])

        reil_instrs = self._translator.translate(asm_instrs)

        regs_initial = {
            "eax": 0x1,
            "ebx": 0x2,
        }

        regs_final, _ = self._emulator.execute_lite(reil_instrs,
                                                    context=regs_initial)

        self.assertEqual(regs_final["eax"], 0x3)
        self.assertEqual(regs_final["ebx"], 0x2)

    def test_loop(self):
        # 0x08048060 : b8 00 00 00 00   mov eax,0x0
        # 0x08048065 : bb 0a 00 00 00   mov ebx,0xa
        # 0x0804806a : 83 c0 01         add eax,0x1
        # 0x0804806d : 83 eb 01         sub ebx,0x1
        # 0x08048070 : 83 fb 00         cmp ebx,0x0
        # 0x08048073 : 75 f5            jne 0x0804806a

        asm_instrs_str = [(0x08048060, "mov eax,0x0", 5)]
        asm_instrs_str += [(0x08048065, "mov ebx,0xa", 5)]
        asm_instrs_str += [(0x0804806a, "add eax,0x1", 3)]
        asm_instrs_str += [(0x0804806d, "sub ebx,0x1", 3)]
        asm_instrs_str += [(0x08048070, "cmp ebx,0x0", 3)]
        asm_instrs_str += [(0x08048073, "jne 0x0804806a", 2)]

        asm_instrs = []

        for addr, asm, size in asm_instrs_str:
            asm_instr = self._asm_parser.parse(asm)
            asm_instr.address = addr
            asm_instr.size = size

            asm_instrs.append(asm_instr)

        reil_instrs = self.__translate(asm_instrs)

        regs_final, _ = self._emulator.execute(reil_instrs,
                                               start=0x08048060 << 8)

        self.assertEqual(regs_final["eax"], 0xa)
        self.assertEqual(regs_final["ebx"], 0x0)

    def test_mov(self):
        asm_instrs = [self._asm_parser.parse("mov eax, 0xdeadbeef")]
        asm_instrs += [self._asm_parser.parse("mov al, 0x12")]
        asm_instrs += [self._asm_parser.parse("mov ah, 0x34")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self._translator.translate(asm_instrs[0])
        reil_instrs += self._translator.translate(asm_instrs[1])
        reil_instrs += self._translator.translate(asm_instrs[2])

        regs_initial = {
            "eax": 0xffffffff,
        }

        regs_final, _ = self._emulator.execute_lite(reil_instrs,
                                                    context=regs_initial)

        self.assertEqual(regs_final["eax"], 0xdead3412)

    def test_pre_hanlder(self):
        def pre_hanlder(emulator, instruction, parameter):
            paramter.append(True)

        asm = ["mov eax, ebx"]

        x86_instrs = map(self._asm_parser.parse, asm)
        self.__set_address(0xdeadbeef, x86_instrs)
        reil_instrs = map(self._translator.translate, x86_instrs)

        paramter = []

        self._emulator.set_instruction_pre_handler(pre_hanlder, paramter)

        reil_ctx_out, reil_mem_out = self._emulator.execute_lite(
            reil_instrs[0])

        self.assertTrue(len(paramter) > 0)

    def test_post_hanlder(self):
        def post_hanlder(emulator, instruction, parameter):
            paramter.append(True)

        asm = ["mov eax, ebx"]

        x86_instrs = map(self._asm_parser.parse, asm)
        self.__set_address(0xdeadbeef, x86_instrs)
        reil_instrs = map(self._translator.translate, x86_instrs)

        paramter = []

        self._emulator.set_instruction_post_handler(post_hanlder, paramter)

        reil_ctx_out, reil_mem_out = self._emulator.execute_lite(
            reil_instrs[0])

        self.assertTrue(len(paramter) > 0)

    def test_zero_division_error_1(self):
        asm_instrs = [self._asm_parser.parse("div ebx")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self._translator.translate(asm_instrs[0])

        regs_initial = {
            "eax": 0x2,
            "edx": 0x2,
            "ebx": 0x0,
        }

        self.assertRaises(ReilCpuZeroDivisionError,
                          self._emulator.execute_lite,
                          reil_instrs,
                          context=regs_initial)

    def test_zero_division_error_2(self):
        instrs = ["mod [DWORD eax, DWORD ebx, DWORD t0]"]

        reil_instrs = self._reil_parser.parse(instrs)

        reil_instrs[0].address = 0xdeadbeef00

        regs_initial = {
            "eax": 0x2,
            "ebx": 0x0,
        }

        self.assertRaises(ReilCpuZeroDivisionError,
                          self._emulator.execute_lite,
                          reil_instrs,
                          context=regs_initial)

    def test_invalid_address_error_1(self):
        asm_instrs = [self._asm_parser.parse("jmp eax")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self.__translate(asm_instrs)

        regs_initial = {
            "eax": 0xffffffff,
        }

        self.assertRaises(ReilCpuInvalidAddressError,
                          self._emulator.execute,
                          reil_instrs,
                          start=0xdeadbeef << 8,
                          registers=regs_initial)

    def test_invalid_address_error_2(self):
        asm_instrs = [self._asm_parser.parse("mov eax, 0xdeadbeef")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self.__translate(asm_instrs)

        regs_initial = {
            "eax": 0xffffffff,
        }

        self.assertRaises(ReilCpuInvalidAddressError,
                          self._emulator.execute,
                          reil_instrs,
                          start=0xdeadbef0 << 8,
                          registers=regs_initial)

    # Auxiliary methods
    # ======================================================================== #
    def __set_address(self, address, asm_instrs):
        addr = address

        for asm_instr in asm_instrs:
            asm_instr.address = addr
            addr += 1

    def __translate(self, asm_instrs):
        instr_container = ReilContainer()

        asm_instr_last = None
        instr_seq_prev = None

        for asm_instr in asm_instrs:
            instr_seq = ReilSequence()

            for reil_instr in self._translator.translate(asm_instr):
                instr_seq.append(reil_instr)

            if instr_seq_prev:
                instr_seq_prev.next_sequence_address = instr_seq.address

            instr_container.add(instr_seq)

            instr_seq_prev = instr_seq

        if instr_seq_prev:
            if asm_instr_last:
                instr_seq_prev.next_sequence_address = (
                    asm_instr_last.address + asm_instr_last.size) << 8

        # instr_container.dump()

        return instr_container
Exemple #16
0
class ArmGadgetClassifierTests(unittest.TestCase):
    def setUp(self):

        self._arch_info = ArmArchitectureInformation(ARCH_ARM_MODE_32)
        self._smt_solver = SmtSolver()
        self._smt_translator = SmtTranslator(self._smt_solver,
                                             self._arch_info.address_size)
        self._ir_emulator = ReilEmulator(self._arch_info.address_size)

        self._ir_emulator.set_arch_registers(self._arch_info.registers_gp_all)
        self._ir_emulator.set_arch_registers_size(
            self._arch_info.registers_size)
        self._ir_emulator.set_reg_access_mapper(self._arch_info.alias_mapper)

        self._smt_translator.set_reg_access_mapper(
            self._arch_info.alias_mapper)
        self._smt_translator.set_arch_registers_size(
            self._arch_info.registers_size)

        self._code_analyzer = CodeAnalyzer(self._smt_solver,
                                           self._smt_translator)

        self._g_classifier = GadgetClassifier(self._ir_emulator,
                                              self._arch_info)
        self._g_verifier = GadgetVerifier(self._code_analyzer, self._arch_info)

    def _find_and_classify_gadgets(self, binary):
        g_finder = GadgetFinder(
            ArmDisassembler(), binary,
            ArmTranslator(translation_mode=LITE_TRANSLATION), ARCH_ARM,
            ARCH_ARM_MODE_32)

        g_candidates = g_finder.find(0x00000000, len(binary), instrs_depth=4)
        g_classified = self._g_classifier.classify(g_candidates[0])

        #         Debug:
        #         self._print_candidates(g_candidates)
        #         self._print_classified(g_classified)

        return g_candidates, g_classified

    def test_move_register_1(self):
        # testing : dst_reg <- src_reg
        binary = "\x04\x00\xa0\xe1"  # 0x00 : (4)  mov    r0, r4
        binary += "\x31\xff\x2f\xe1"  # 0x04 : (4)  blx    r1

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 2)

        self.assertEquals(g_classified[0].type, GadgetType.MoveRegister)
        self.assertEquals(g_classified[0].sources,
                          [ReilRegisterOperand("r4", 32)])
        self.assertEquals(g_classified[0].destination,
                          [ReilRegisterOperand("r0", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 1)
        self.assertTrue(
            ReilRegisterOperand("r14", 32) in
            g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_move_register_2(self):
        # testing : dst_reg <- src_reg
        binary = "\x00\x00\x84\xe2"  # 0x00 : (4)  add    r0, r4, #0
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.MoveRegister)
        self.assertEquals(g_classified[0].sources,
                          [ReilRegisterOperand("r4", 32)])
        self.assertEquals(g_classified[0].destination,
                          [ReilRegisterOperand("r0", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    # TODO: test_move_register_n: mul r0, r4, #1

    def test_load_constant_1(self):
        # testing : dst_reg <- constant
        binary = "\x0a\x20\xa0\xe3"  # 0x00 : (4)  mov    r2, #10
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEquals(g_classified[0].sources,
                          [ReilImmediateOperand(10, 32)])
        self.assertEquals(g_classified[0].destination,
                          [ReilRegisterOperand("r2", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)
        self.assertFalse(
            ReilRegisterOperand("r2", 32) in
            g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_2(self):
        # testing : dst_reg <- constant
        binary = "\x02\x20\x42\xe0"  # 0x00 : (4)  sub    r2, r2, r2
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEquals(g_classified[0].sources,
                          [ReilImmediateOperand(0, 32)])
        self.assertEquals(g_classified[0].destination,
                          [ReilRegisterOperand("r2", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)
        self.assertFalse(
            ReilRegisterOperand("r2", 32) in
            g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_3(self):
        # testing : dst_reg <- constant
        binary = "\x02\x20\x22\xe0"  # 0x00 : (4)  eor    r2, r2, r2
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEquals(g_classified[0].sources,
                          [ReilImmediateOperand(0, 32)])
        self.assertEquals(g_classified[0].destination,
                          [ReilRegisterOperand("r2", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)
        self.assertFalse(
            ReilRegisterOperand("r2", 32) in
            g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_4(self):
        # testing : dst_reg <- constant
        binary = "\x00\x20\x02\xe2"  # 0x00 : (4)  and    r2, r2, #0
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEquals(g_classified[0].sources,
                          [ReilImmediateOperand(0, 32)])
        self.assertEquals(g_classified[0].destination,
                          [ReilRegisterOperand("r2", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)
        self.assertFalse(
            ReilRegisterOperand("r2", 32) in
            g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_5(self):
        # testing : dst_reg <- constant
        binary = "\x00\x20\x02\xe2"  # and    r2, r2, #0
        binary += "\x21\x20\x82\xe3"  # orr    r2, r2, #33
        binary += "\x1e\xff\x2f\xe1"  # bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEquals(g_classified[0].sources,
                          [ReilImmediateOperand(33, 32)])
        self.assertEquals(g_classified[0].destination,
                          [ReilRegisterOperand("r2", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)
        self.assertFalse(
            ReilRegisterOperand("r2", 32) in
            g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_arithmetic_add_1(self):
        # testing : dst_reg <- src1_reg + src2_reg
        binary = "\x08\x00\x84\xe0"  # 0x00 : (4)  add    r0, r4, r8
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.Arithmetic)
        self.assertEquals(
            g_classified[0].sources,
            [ReilRegisterOperand("r4", 32),
             ReilRegisterOperand("r8", 32)])
        self.assertEquals(g_classified[0].destination,
                          [ReilRegisterOperand("r0", 32)])
        self.assertEquals(g_classified[0].operation, "+")

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_arithmetic_sub_1(self):
        # testing : dst_reg <- src1_reg + src2_reg
        binary = "\x08\x00\x44\xe0"  # 0x00 : (4)  sub    r0, r4, r8
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.Arithmetic)
        self.assertEquals(
            g_classified[0].sources,
            [ReilRegisterOperand("r4", 32),
             ReilRegisterOperand("r8", 32)])
        self.assertEquals(g_classified[0].destination,
                          [ReilRegisterOperand("r0", 32)])
        self.assertEquals(g_classified[0].operation, "-")

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_memory_1(self):
        # testing : dst_reg <- m[src_reg]
        binary = "\x00\x30\x94\xe5"  # 0x00 : (4)  ldr    r3, [r4]
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadMemory)
        self.assertEquals(
            g_classified[0].sources,
            [ReilRegisterOperand("r4", 32),
             ReilImmediateOperand(0x0, 32)])
        self.assertEquals(g_classified[0].destination,
                          [ReilRegisterOperand("r3", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_memory_2(self):
        # testing : dst_reg <- m[src_reg + offset]
        binary = "\x33\x30\x94\xe5"  # 0x00 : (4)  ldr    r3, [r4 + 0x33]
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadMemory)
        self.assertEquals(
            g_classified[0].sources,
            [ReilRegisterOperand("r4", 32),
             ReilImmediateOperand(0x33, 32)])
        self.assertEquals(g_classified[0].destination,
                          [ReilRegisterOperand("r3", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    # TODO: ARM's ldr rd, [rn, r2] is not a valid classification right now

    def test_store_memory_1(self):
        # testing : dst_reg <- m[src_reg]
        binary = "\x00\x30\x84\xe5"  # 0x00 : (4)  str    r3, [r4]
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.StoreMemory)
        self.assertEquals(g_classified[0].sources,
                          [ReilRegisterOperand("r3", 32)])
        self.assertEquals(
            g_classified[0].destination,
            [ReilRegisterOperand("r4", 32),
             ReilImmediateOperand(0x0, 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_store_memory_2(self):
        # testing : dst_reg <- m[src_reg + offset]
        binary = "\x33\x30\x84\xe5"  # 0x00 : (4)  str    r3, [r4 + 0x33]
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.StoreMemory)
        self.assertEquals(g_classified[0].sources,
                          [ReilRegisterOperand("r3", 32)])
        self.assertEquals(
            g_classified[0].destination,
            [ReilRegisterOperand("r4", 32),
             ReilImmediateOperand(0x33, 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_arithmetic_load_add_1(self):
        # testing : dst_reg <- dst_reg + mem[src_reg]
        binary = "\x00\x30\x94\xe5"  # 0x00 : (4)  ldr    r3, [r4]
        binary += "\x03\x00\x80\xe0"  # 0x00 : (4)  add    r0, r0, r3
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 2)

        self.assertEquals(g_classified[1].type, GadgetType.ArithmeticLoad)
        self.assertEquals(g_classified[1].sources, [
            ReilRegisterOperand("r0", 32),
            ReilRegisterOperand("r4", 32),
            ReilImmediateOperand(0x0, 32)
        ])
        self.assertEquals(g_classified[1].destination,
                          [ReilRegisterOperand("r0", 32)])
        self.assertEquals(g_classified[1].operation, "+")

        self.assertEquals(len(g_classified[1].modified_registers), 1)

        self.assertFalse(
            ReilRegisterOperand("r0", 32) in
            g_classified[1].modified_registers)
        self.assertTrue(
            ReilRegisterOperand("r3", 32) in
            g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def test_arithmetic_load_add_2(self):
        # testing : dst_reg <- dst_reg + mem[src_reg + offset]
        binary = "\x22\x30\x94\xe5"  # 0x00 : (4)  ldr    r3, [r4, 0x22]
        binary += "\x03\x00\x80\xe0"  # 0x00 : (4)  add    r0, r0, r3
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 2)

        self.assertEquals(g_classified[1].type, GadgetType.ArithmeticLoad)
        self.assertEquals(g_classified[1].sources, [
            ReilRegisterOperand("r0", 32),
            ReilRegisterOperand("r4", 32),
            ReilImmediateOperand(0x22, 32)
        ])
        self.assertEquals(g_classified[1].destination,
                          [ReilRegisterOperand("r0", 32)])
        self.assertEquals(g_classified[1].operation, "+")

        self.assertEquals(len(g_classified[1].modified_registers), 1)

        self.assertFalse(
            ReilRegisterOperand("r0", 32) in
            g_classified[1].modified_registers)
        self.assertTrue(
            ReilRegisterOperand("r3", 32) in
            g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def test_arithmetic_store_add_1(self):
        # testing : m[dst_reg] <- m[dst_reg] + src_reg
        binary = "\x00\x30\x94\xe5"  # 0x00 : (4)  ldr    r3, [r4]
        binary += "\x03\x30\x80\xe0"  # 0x00 : (4)  add    r3, r0, r3
        binary += "\x00\x30\x84\xe5"  # 0x00 : (4)  str    r3, [r4]
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 2)

        self.assertEquals(g_classified[1].type, GadgetType.ArithmeticStore)
        self.assertEquals(g_classified[1].sources, [
            ReilRegisterOperand("r4", 32),
            ReilImmediateOperand(0x0, 32),
            ReilRegisterOperand("r0", 32)
        ])
        self.assertEquals(
            g_classified[1].destination,
            [ReilRegisterOperand("r4", 32),
             ReilImmediateOperand(0x0, 32)])
        self.assertEquals(g_classified[1].operation, "+")

        self.assertEquals(len(g_classified[1].modified_registers), 1)

        self.assertFalse(
            ReilRegisterOperand("r4", 32) in
            g_classified[1].modified_registers)
        self.assertTrue(
            ReilRegisterOperand("r3", 32) in
            g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def test_arithmetic_store_add_2(self):
        # testing : dst_reg <- dst_reg + mem[src_reg + offset]
        binary = "\x22\x30\x94\xe5"  # 0x00 : (4)  ldr    r3, [r4, 0x22]
        binary += "\x03\x30\x80\xe0"  # 0x00 : (4)  add    r3, r0, r3
        binary += "\x22\x30\x84\xe5"  # 0x00 : (4)  str    r3, [r4, 0x22]
        binary += "\x1e\xff\x2f\xe1"  # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 2)

        self.assertEquals(g_classified[1].type, GadgetType.ArithmeticStore)
        self.assertEquals(g_classified[1].sources, [
            ReilRegisterOperand("r4", 32),
            ReilImmediateOperand(0x22, 32),
            ReilRegisterOperand("r0", 32)
        ])
        self.assertEquals(
            g_classified[1].destination,
            [ReilRegisterOperand("r4", 32),
             ReilImmediateOperand(0x22, 32)])
        self.assertEquals(g_classified[1].operation, "+")

        self.assertEquals(len(g_classified[1].modified_registers), 1)

        self.assertFalse(
            ReilRegisterOperand("r4", 32) in
            g_classified[1].modified_registers)
        self.assertTrue(
            ReilRegisterOperand("r3", 32) in
            g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def _print_candidates(self, candidates):
        print "Candidates :"

        for gadget in candidates:
            print gadget
            print "-" * 10

    def _print_classified(self, classified):
        print "Classified :"

        for gadget in classified:
            print gadget
            print gadget.type
            print "-" * 10
Exemple #17
0
class GadgetTools():
    """Find, classify and combine gadgets from an x86 ELF binary.

    The instance owns the emulator, SMT translator, classifier and the
    per-category gadget stores (``regset``, ``ccf``, ``ams``, ``memstr``)
    that the payload-building helpers below delegate to.
    """

    def __init__(self, binary):
        """Open ``binary`` as an ELF file and build the analysis components.

        Args:
            binary: passed unchanged to ``elffile.ELFFile``.

        Raises:
            ValueError: if the ELF class is neither 32 nor 64, since no
                architecture information could be selected.
        """
        self.elf = elffile.ELFFile(binary)

        if self.elf.elfclass == 32:
            self.arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)
        elif self.elf.elfclass == 64:
            self.arch_info = X86ArchitectureInformation(ARCH_X86_MODE_64)
        else:
            # Previously this fell through and failed later with an
            # AttributeError on self.arch_info.
            raise ValueError("unsupported ELF class: {0!r}".format(self.elf.elfclass))

        self.emulator = ReilEmulator(self.arch_info.address_size)
        self.emulator.set_arch_registers(self.arch_info.registers_gp)
        self.emulator.set_arch_registers_size(self.arch_info.register_size)
        self.emulator.set_reg_access_mapper(self.arch_info.register_access_mapper())

        self.classifier = GadgetClassifier(self.emulator, self.arch_info)

        self.smt_solver = SmtSolver()
        self.smt_translator = SmtTranslator(self.smt_solver, self.arch_info.address_size)

        self.smt_translator.set_reg_access_mapper(self.arch_info.register_access_mapper())
        self.smt_translator.set_arch_registers_size(self.arch_info.register_size)

        self.code_analyzer = CodeAnalyzer(self.smt_solver, self.smt_translator)

        # Raw gadgets keyed by address, and typed gadgets keyed by gadget type.
        self.gadgets = {}
        self.classified_gadgets = {}

        # Each store receives this instance so it can reach the shared tools.
        self.regset = RegSet(self)
        self.ccf = CCFlag(self)
        self.ams = ArithmeticStore(self)
        self.memstr = MemoryStore(self)

        self.reil_translator = X86Translator(architecture_mode=self.arch_info.architecture_mode,
                                             translation_mode=FULL_TRANSLATION)

    def find_gadgets(self, max_instr=10, max_bytes=15):
        """Collect ret-terminated gadgets from the executable sections.

        Only ``SHT_PROGBITS`` sections with flag bit 0x4 (SHF_EXECINSTR) are
        scanned. A gadget is kept when its last instruction is a ``Ret`` whose
        immediate operand, if any, is at most 0x10. Results are stored in
        ``self.gadgets`` keyed by gadget address.

        Args:
            max_instr: forwarded to ``GadgetFinder.find``.
            max_bytes: forwarded to ``GadgetFinder.find``.
        """
        logging.info('searching gadgets in binary..')
        for s in self.elf.iter_sections():
            if not ((s.header.sh_type == 'SHT_PROGBITS') and (s.header.sh_flags & 0x4)):
                continue

            sz = s.header.sh_size
            base = s.header.sh_addr

            # Fetch the section contents once instead of on every byte read,
            # and bind them as defaults so the reader does not depend on the
            # loop variables.
            data = s.data()
            mem = Memory(lambda x, y, data=data, base=base: data[x - base], None)
            gfinder = GadgetFinder(X86Disassembler(architecture_mode=self.arch_info.architecture_mode),
                                   mem,
                                   self.reil_translator)

            logging.info("searching gadgets in section " + s.name + "...")

            for g in gfinder.find(base, base + sz - 1, max_instr, max_bytes):
                ret = g.instrs[-1].asm_instr
                if not isinstance(ret, Ret):
                    continue
                if len(ret.operands) > 0 and ret.operands[0].immediate > 0x10:
                    continue

                self.gadgets[g.address] = g

            # The count is cumulative over all sections scanned so far.
            logging.info("found {0} gadgets".format(len(self.gadgets)))

    def classify_gadgets(self):
        """Classify every gadget and group the results by type.

        Fills ``self.classified_gadgets`` with a list of typed gadgets for
        each gadget type the classifier reports.
        """
        # setting cf to 0 so that an impossible random value does not
        # invalidate the results
        self.classifier.set_reg_init({'cf': 0})

        for g in self.gadgets.itervalues():
            for tg in self.classifier.classify(g):
                self.classified_gadgets.setdefault(tg.type, []).append(tg)

    def find_reg_set_gadgets(self):
        """Feed every found gadget to the register-set store."""
        self.regset.add(self.gadgets.itervalues())

    def read_carrier_flag(self, g):
        """Emulate ``g`` from a random context and check for 'eflags'.

        Returns:
            True if ``'eflags'`` is among the emulator's registers after the
            run, False otherwise.
        """
        # NOTE(review): only 'eflags' is checked, also in 64-bit mode;
        # confirm whether 'rflags' should be considered as well.
        self.emulator.reset()
        regs_init = utils.make_random_regs_context(self.arch_info)

        self.emulator.execute_lite(g.get_ir_instrs(), regs_init)

        return 'eflags' in self.emulator.registers

    def find_arithmetic_mem_set_gadgets(self):
        """Feed the ArithmeticStore gadgets, if any, to the ``ams`` store."""
        # Same guard as find_memory_store: no gadget of this type is not an
        # error (this used to raise KeyError).
        if GadgetType.ArithmeticStore not in self.classified_gadgets:
            return

        self.ams.add(self.classified_gadgets[GadgetType.ArithmeticStore])

    def get_stack_slide_chunk(self, slide):
        """Return the register-set store's stack-slide chunk for ``slide``."""
        # TODO: do not use only pop gadgets; chain several slides when a
        # longer slide is wanted
        return self.regset.get_slide_stack_chunk(slide)

    def get_ret_func_chunk(self, args, address):
        """Return a chainable chunk that returns to ``address`` with ``args`` set up.

        In 64-bit mode the arguments are placed in rdi, rsi, rdx, rcx, r8 and
        r9 (in that order) through the register-set store. In 32-bit mode a
        ``RetToAddress32`` chunk is combined with a stack slide of
        ``len(args) * 4`` bytes.

        Args:
            args (list): values to set up as function arguments.
            address (int): address to return to.

        Returns:
            The combined chunk, or None when the architecture size is
            neither 32 nor 64.

        Raises:
            BaseException: in 64-bit mode when more than six arguments are
                given.
        """
        if self.arch_info.architecture_size == 64:
            if len(args) > 6:
                raise BaseException("chunk for calling a function whit more of six args isn't implemented")

            args_regs = ['rdi', 'rsi', 'rdx', 'rcx', 'r8', 'r9']
            # zip stops at the shorter sequence, i.e. at len(args) <= 6.
            regs_values = dict(zip(args_regs, args))
            regs_c = self.regset.get_chunk(regs_values)
            ret_c = PayloadChunk("", self.arch_info, address)
            return PayloadChunk.get_general_chunk([regs_c, ret_c])

        if self.arch_info.architecture_size == 32:
            slide_c = self.regset.get_slide_stack_chunk(len(args) * 4)
            print(slide_c)
            ret_c = RetToAddress32(args, address, self.arch_info)
            print(ret_c)
            return PayloadChunk.get_general_chunk([ret_c, slide_c])

    def get_mem_set_libc_read_chunk(self, location, fd, size, read_address):
        """Return a chunk built from ``MemSetLibcRead32`` plus a stack slide.

        Only 32-bit mode is implemented; the arguments are forwarded
        unchanged to ``MemSetLibcRead32``.

        Returns:
            The combined chunk in 32-bit mode, otherwise None (64-bit mode
            additionally prints "TO IMPLEMENT").
        """
        if self.arch_info.architecture_size == 64:
            print("TO IMPLEMENT")
            return

        if self.arch_info.architecture_size == 32:
            # Slide over the three 4-byte stack slots.
            slide_chunk = self.regset.get_slide_stack_chunk(4 * 3)
            pl_chunk = MemSetLibcRead32(location,
                                        fd,
                                        size,
                                        read_address,
                                        self.arch_info)

            # Returning inside the branch avoids reading unbound locals when
            # the architecture size is unexpected.
            return PayloadChunk.get_general_chunk([pl_chunk, slide_chunk])

    def build_mem_add(self, location, offset, size, mem_pre=None):
        """Return the ``ams`` store's memory-add chunk for the given arguments."""
        return self.ams.get_memory_add_chunk(location, offset, size, mem_pre)

    def check_mem_side_effects_and_stack_end(self, g, regs_init, location, size):
        """Emulate ``g`` and report unintended memory writes and stack movement.

        The stack register in ``regs_init`` is overwritten with a fixed base
        (0x50) before emulation, so the caller's dict is modified.

        Args:
            g: gadget providing ``get_ir_instrs()``.
            regs_init (dict): initial register values; mutated in place.
            location (int): first address of the intended write.
            size (int): width of the intended write in bits.

        Returns:
            A pair ``(mem_side_effects, stack_offset)``: the addresses
            written outside ``[location, location + size / 8)`` that are not
            attributed to stack reads, and the final stack pointer relative
            to the base minus one address-sized slot.

        Raises:
            Whatever ``execute_lite`` raises. The failure used to be
            swallowed and then surfaced as an UnboundLocalError.
        """
        stack_reg = 'rsp' if self.arch_info.architecture_size == 64 else 'esp'

        stack_base = 0x50
        regs_init[stack_reg] = stack_base

        # TODO: emulation can fail (e.g. division by zero while searching
        # for ccf gadgets); callers have to cope with the exception.
        cregs, mem_final = self.emulator.execute_lite(g.get_ir_instrs(), regs_init)

        stack_delta = abs(stack_base - cregs[stack_reg])
        target_end = location + size // 8

        mem_side_effects = []

        for addr in mem_final.get_addresses():
            # Bytes of the intended destination are not side effects.
            if location <= addr < target_end:
                continue

            sp, vp = mem_final.try_read_prev(addr, 8)
            sn, vn = mem_final.try_read(addr, 8)

            # Quick fix: a fresh value within the distance the stack pointer
            # moved is treated as a stack access. Stack reads and real side
            # effects should be told apart properly.
            if sn and not sp and stack_base - stack_delta <= addr <= stack_base + stack_delta:
                continue

            if (sp and sn and vp != vn) or (sn and not sp):
                mem_side_effects.append(addr)

        return mem_side_effects, cregs[stack_reg] - stack_base - self.arch_info.address_size // 8

    def find_memory_store(self):
        """Feed the StoreMemory gadgets, if any, to the ``memstr`` store."""
        self.mem_set_gadgets = {}

        if GadgetType.StoreMemory not in self.classified_gadgets:
            return

        self.memstr.add(self.classified_gadgets[GadgetType.StoreMemory])

    def find_ccfs(self):
        """Feed every found gadget to the ``ccf`` store."""
        self.ccf.add(self.gadgets.values())
Exemple #18
0
class X86TranslationTests(unittest.TestCase):
    """Compare REIL translations of single x86 instructions with pyasmjit.

    Every test runs one instruction through ``pyasmjit.execute`` and through
    the REIL emulator from the same initial context, then requires each
    register named in that context to hold the same 64-bit masked value.
    """

    def setUp(self):
        self.trans_mode = FULL_TRANSLATION

        self.arch_mode = ARCH_X86_MODE_64

        self.arch_info = X86ArchitectureInformation(self.arch_mode)

        self.x86_parser = X86Parser(self.arch_mode)
        self.x86_translator = X86Translator(self.arch_mode, self.trans_mode)
        self.smt_solver = SmtSolver()
        self.smt_translator = SmtTranslator(self.smt_solver, self.arch_info.address_size)
        self.reil_emulator = ReilEmulator(self.arch_info.address_size)

        self.reil_emulator.set_arch_registers(self.arch_info.registers_gp)
        self.reil_emulator.set_arch_registers_size(self.arch_info.register_size)
        self.reil_emulator.set_reg_access_mapper(self.arch_info.register_access_mapper())

        self.smt_translator.set_reg_access_mapper(self.arch_info.register_access_mapper())
        self.smt_translator.set_arch_registers_size(self.arch_info.register_size)

    def test_lea(self):
        self.__check_translation(["lea eax, [ebx + 0x100]"])

    def test_cld(self):
        self.__check_translation(["cld"])

    def test_clc(self):
        self.__check_translation(["clc"])

    def test_nop(self):
        self.__check_translation(["nop"])

    def test_test(self):
        self.__check_translation(["test eax, ebx"])

    def test_not(self):
        self.__check_translation(["not eax"])

    def test_xor(self):
        self.__check_translation(["xor eax, ebx"])

    def test_or(self):
        self.__check_translation(["or eax, ebx"])

    def test_and(self):
        self.__check_translation(["and eax, ebx"])

    def test_cmp(self):
        self.__check_translation(["cmp eax, ebx"])

    def test_neg(self):
        self.__check_translation(["neg eax"])

    def test_dec(self):
        self.__check_translation(["dec eax"])

    def test_inc(self):
        self.__check_translation(["inc eax"])

    def test_div(self):
        # Uses its own context with rdx cleared, presumably so that the
        # edx:eax / ebx quotient fits in eax.
        context_init = {
            'rax'    : 0x10,
            'rbx'    : 0x2,
            'rdx'    : 0x0,
            'rflags' : 0x202,
        }

        self.__check_translation(["div ebx"], context_init=context_init)

    def test_imul(self):
        self.__check_translation(["imul eax, ebx"])

    def test_mul(self):
        self.__check_translation(["mul ebx"])

    def test_sbb(self):
        self.__check_translation(["sbb eax, ebx"])

    def test_sub(self):
        self.__check_translation(["sub eax, ebx"])

    def test_adc(self):
        self.__check_translation(["adc eax, ebx"])

    def test_add(self):
        self.__check_translation(["add eax, ebx"])

    def test_xchg(self):
        self.__check_translation(["xchg eax, ebx"])

    def test_movzx(self):
        self.__check_translation(["movzx eax, bx"])

    def test_mov(self):
        self.__check_translation(["mov eax, ebx"])

    def test_shr(self):
        self.__check_translation(["shr eax, 3"])

    def test_shl(self):
        self.__check_translation(["shl eax, 3"])

    def test_sal(self):
        self.__check_translation(["sal eax, 3"])

    def test_sar(self):
        # The only test that assigns an address to the parsed instruction.
        self.__check_translation(["sar eax, 3"], address=0xdeadbeef)

    def __check_translation(self, asm, context_init=None, address=None):
        """Run ``asm`` through pyasmjit and REIL and assert the contexts match.

        Args:
            asm (list): assembly source lines, one instruction per item.
            context_init (dict): initial register values; defaults to
                ``__init_context()``.
            address (int): if given, assigned to the first parsed
                instruction before translation.
        """
        x86_instrs = [self.x86_parser.parse(line) for line in asm]

        if address is not None:
            x86_instrs[0].address = address

        reil_instrs = [self.x86_translator.translate(instr) for instr in x86_instrs]

        if context_init is None:
            context_init = self.__init_context()

        # The pyasmjit return value and the final REIL memory are not used.
        _, x86_context_out = pyasmjit.execute("\n".join(asm), context_init)
        reil_context_out, _ = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=self.__update_flags_from_rflags(context_init)
        )

        self.assertTrue(self.__compare_contexts(context_init, x86_context_out, reil_context_out))

    def __init_context(self):
        """Return the default initial register context shared by the tests."""
        return {
            'rax'    : 0xa,
            'rbx'    : 0x2,
            'rcx'    : 0xb,
            'rdx'    : 0xc,
            'rdi'    : 0xd,
            'rsi'    : 0xe,
            'rflags' : 0x202,
        }

    def __compare_contexts(self, context_init, x86_context, reil_context):
        """Return True if both contexts agree on every key of ``context_init``.

        Values are compared modulo 2**64. The first mismatching register is
        printed (x86 value, then REIL value) and the comparison stops there.
        """
        mask = 2**64 - 1

        for reg in sorted(context_init.keys()):
            x86_value = mask & x86_context[reg]
            reil_value = mask & reil_context[reg]

            if x86_value != reil_value:
                print("%s : %s " % (reg, hex(x86_value)))
                print("%s : %s " % (reg, hex(reil_value)))
                return False

        return True

    def __print_contexts(self, context_init, x86_context, reil_context):
        """Print a table comparing both contexts for every initial register.

        Values are shown modulo 2**64 in hexadecimal and decimal; rows that
        differ are flagged with a trailing "<".
        """
        mask = 2**64 - 1

        header_fmt = "{0:^8s} : {1:>16s} ?= {2:<16s}"
        header = header_fmt.format("Register", "x86", "REIL")
        ruler  = "-" * len(header)

        print(header)
        print(ruler)

        fmt = "{0:>8s} : {1:016x} {eq} {2:016x} ({1:>5d} {eq} {2:<5d}) {marker}"

        for reg in sorted(context_init.keys()):
            x86_value = mask & x86_context[reg]
            reil_value = mask & reil_context[reg]

            if x86_value != reil_value:
                eq = "!="
                marker = "<"
            else:
                eq = "=="
                marker = ""

            # Single-argument call form, valid as statement and as function.
            print(fmt.format(reg, x86_value, reil_value, eq=eq, marker=marker))

    def __update_rflags(self, reil_context_out, x86_context_out):
        """Return the REIL context restricted to x86 keys, with rflags rebuilt.

        ``rflags`` is assembled from the single-bit REIL flags; AF and PF are
        copied from the x86 result instead of taken from REIL. Bits not
        listed below are left at zero.
        """
        reil_context = dict((reg, value) for reg, value in reil_context_out.items() if reg in x86_context_out.keys())

        reil_context['rflags'] = 0xffffffff & (
            reil_context_out['of']   << 11 |    # OF
            reil_context_out['df']   << 10 |    # DF
            0x1                      <<  9 |    # IF, always set
            reil_context_out['sf']   <<  7 |    # SF
            reil_context_out['zf']   <<  6 |    # ZF
            (x86_context_out['rflags'] & 0x10) | # AF, copied from x86
            (x86_context_out['rflags'] & 0x4)  | # PF, copied from x86
            0x1                      <<  1 |    # Reserved, always set
            reil_context_out['cf']   <<  0      # CF
        )

        return reil_context

    def __update_flags_from_rflags(self, reil_context):
        """Return a copy of ``reil_context`` with the individual flags added.

        The flags are extracted from ``eflags`` if present, otherwise from
        ``rflags``; without either register the copy is returned unchanged.
        Each flag is stored as 0 or 1, which is the form ``__update_rflags``
        shifts back into place.
        """
        reil_context_out = dict(reil_context)

        flags_reg = None

        if 'rflags' in reil_context_out:
            flags_reg = 'rflags'

        # eflags takes precedence when both are present.
        if 'eflags' in reil_context_out:
            flags_reg = 'eflags'

        if flags_reg:
            flags = reil_context_out[flags_reg]

            # Shift before masking: a plain "& 2**bit" yields the bit's
            # weight (e.g. 2048 for OF) rather than 0 or 1.
            flag_bits = (
                ('of', 11),
                ('df', 10),
                ('sf', 7),
                ('zf', 6),
                ('af', 4),
                ('pf', 2),
                ('cf', 0),
            )

            for name, bit in flag_bits:
                reil_context_out[name] = (flags >> bit) & 0x1

        return reil_context_out
class ArmTranslationTests(unittest.TestCase):
    def setUp(self):
        """Create the fixtures shared by the tests.

        Stores the translation mode and architecture mode, and builds the
        architecture information, parser, translator and REIL emulator
        from them. Also sets the file name used for failing contexts.
        """
        translation = FULL_TRANSLATION
        mode = ARCH_ARM_MODE_32

        self.trans_mode = translation
        self.arch_mode = mode

        arch_info = ArmArchitectureInformation(mode)
        self.arch_info = arch_info

        self.arm_parser = ArmParser(mode)
        self.arm_translator = ArmTranslator(mode, translation)
        self.reil_emulator = ReilEmulator(arch_info)

        self.context_filename = "failing_context.data"

    def __init_context(self):
        """Initialize register with random values.
        """
        context = self.__create_random_context()

        return context

    def __create_random_context(self):
        context = {}

        for reg in self.arch_info.registers_gp_base:
            if reg not in ['r13', 'r14', 'r15']:
                min_value, max_value = 0, 2**self.arch_info.operand_size - 1
                context[reg] = random.randint(min_value, max_value)

        # Only highest 4 bits (N, C, Z, V) are randomized, the rest are left in the default (user mode) value
        context['apsr'] = 0x00000010
        context['apsr'] |= random.randint(0x0, 0xF) << 28

        return context

    def __load_failing_context(self):
        f = open(self.context_filename, "rb")
        context = pickle.load(f)
        f.close()

        return context

    def __save_failing_context(self, context):
        f = open(self.context_filename, "wb")
        pickle.dump(context, f)
        f.close()

    def __compare_contexts(self, context_init, arm_context, reil_context):
        match = True
        mask = 2**32 - 1

        for reg in sorted(context_init.keys()):
            if (arm_context[reg] & mask) != (reil_context[reg] & mask):
                match = False
                break

        return match

    def __print_contexts(self, context_init, arm_context, reil_context):
        out = "Contexts don't match!\n\n"

        header_fmt = " {0:^8s} : {1:^16s} | {2:>16s} ?= {3:<16s}\n"
        header = header_fmt.format("Register", "Initial", "ARM", "REIL")
        ruler = "-" * len(header) + "\n"

        out += header
        out += ruler

        fmt = " {0:>8s} : {1:08x} | {2:08x} {eq} {3:08x} {marker}\n"

        mask = 2**64 - 1

        for reg in sorted(context_init.keys()):
            if (arm_context[reg] & mask) != (reil_context[reg] & mask):
                eq, marker = "!=", "<"
            else:
                eq, marker = "==", ""

            out += fmt.format(reg,
                              context_init[reg] & mask,
                              arm_context[reg] & mask,
                              reil_context[reg] & mask,
                              eq=eq,
                              marker=marker)

        # Pretty print flags.
        reg = "apsr"
        fmt = "{0:s} ({1:>4s}) : {2:08x} ({3:s})"

        arm_value = arm_context[reg] & mask
        reil_value = reil_context[reg] & mask

        if arm_value != reil_value:
            arm_flags_str = self.__print_flags(arm_context[reg])
            reil_flags_str = self.__print_flags(reil_context[reg])

            out += "\n"
            out += fmt.format(reg, "ARM", arm_value, arm_flags_str) + "\n"
            out += fmt.format(reg, "reil", reil_value, reil_flags_str)

        return out

    def __print_flags(self, flags_reg):
        # flags
        flags = {
            31: "nf",  # bit 31
            30: "zf",  # bit 30
            29: "cf",  # bit 29
            28: "vf",  # bit 28
        }

        out = ""

        for bit, flag in flags.items():
            flag_str = flag.upper() if flags_reg & 2**bit else flag.lower()
            out += flag_str + " "

        return out[:-1]

    def __fix_reil_flag(self, reil_context, arm_context, flag):
        reil_context_out = dict(reil_context)

        flags_reg = 'eflags' if 'eflags' in reil_context_out else 'rflags'

        arch_size = self.arch_info.architecture_size

        _, bit = self.arch_info.alias_mapper[flag]

        # Clean flag.
        reil_context_out[flags_reg] &= ~(2**bit) & (2**32 - 1)

        # Copy flag.
        reil_context_out[flags_reg] |= (arm_context[flags_reg] & 2**bit)

        return reil_context_out

    def __fix_reil_flags(self, reil_context, arm_context):
        # Remove this when AF and PF are implemented.
        reil_context_out = self.__fix_reil_flag(reil_context, arm_context,
                                                "af")
        reil_context_out = self.__fix_reil_flag(reil_context, arm_context,
                                                "pf")

        return reil_context_out

    def __set_address(self, address, arm_instrs):
        addr = address

        for arm_instr in arm_instrs:
            arm_instr.address = addr
            arm_instr.size = 4
            addr += 4

    def _test_asm_instruction(self, asm):
        """Execute *asm* through pyasmjit and through the REIL emulator
        from the same random context, and assert the results match.

        *asm* is a list of assembly lines. When the contexts differ, the
        initial context is saved with __save_failing_context and the
        assertion message contains the report from __print_contexts.
        """
        print(asm)

        # Build real lists: on Python 3 map() returns a one-shot iterator
        # that __set_address would exhaust, leaving nothing to translate.
        arm_instrs = [self.arm_parser.parse(line) for line in asm]

        self.__set_address(0xdeadbeef, arm_instrs)

        reil_instrs = [self.arm_translator.translate(instr)
                       for instr in arm_instrs]

        ctx_init = self.__init_context()

        _, arm_ctx_out, _ = pyasmjit.arm_execute("\n".join(asm), ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(
            reil_instrs, 0xdeadbeef << 8, context=ctx_init)

        cmp_result = self.__compare_contexts(ctx_init, arm_ctx_out,
                                             reil_ctx_out)

        if not cmp_result:
            self.__save_failing_context(ctx_init)

        self.assertTrue(
            cmp_result,
            self.__print_contexts(ctx_init, arm_ctx_out, reil_ctx_out))

    def _execute_asm(self, asm, ini_addr=0x8000):
        print(asm)

        arm_instrs = map(self.arm_parser.parse, asm)

        self.__set_address(ini_addr, arm_instrs)

        reil_instrs = map(self.arm_translator.translate, arm_instrs)

        ctx_init = self.__init_context()

        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs, 0xdeadbeef << 8, context=ctx_init)

        return reil_ctx_out

    # R11 is used as a dirty register to check if the branch was taken or not
    def test_asm_branch_instruction(self):
        untouched_value = 0x45454545
        touched_value = 0x31313131

        inst_samples_touched = [
            [
                "mov r11, #0x{:x}".format(untouched_value),
                "b #0x800c",
                "mov r11, #0x{:x}".format(touched_value),
                "mov r0, r0",
            ],
            [
                "mov r11, #0x{:x}".format(untouched_value),
                "bx #0x800c",
                "mov r11, #0x{:x}".format(touched_value),
                "mov r0, r0",
            ],
            [
                "mov r11, #0x{:x}".format(untouched_value),
                "bl #0x800c",
                "mov r11, #0x{:x}".format(touched_value),
                "mov r0, r0",
            ],
            [
                "mov r11, #0x{:x}".format(untouched_value),
                "blx #0x800c",
                "mov r11, #0x{:x}".format(touched_value),
                "mov r0, r0",
            ],
            [
                "movs r11, #0x{:x}".format(untouched_value),
                "bne #0x800c",
                "mov r11, #0x{:x}".format(touched_value),
                "mov r0, r0",
            ],
            [
                "mov r11, #0x{:x}".format(untouched_value),
                "mov r1, #0x8010",
                "bx r1",
                "mov r11, #0x{:x}".format(touched_value),
                "mov r0, r0",
            ],
            [
                "mov r11, #0x{:x}".format(untouched_value),
                "mov r1, #0x8010",
                "blx r1",
                "mov r11, #0x{:x}".format(touched_value),
                "mov r0, r0",
            ],
        ]

        for asm in inst_samples_touched:
            reil_ctx_out = self._execute_asm(asm, 0x8000)
            self.assertTrue(reil_ctx_out['r11'] == untouched_value)


#             print reil_ctx_out

# TODO: Merge with previous test function

    def _test_asm_instruction_with_mem(self, asm, reg_mem):
        """Like _test_asm_instruction, but also compares memory.

        A 4096-byte buffer is obtained from ``pyasmjit.arm_alloc`` and its
        address is stored in register *reg_mem* of the initial context.
        After both executions, every byte returned by pyasmjit must equal
        the byte the emulator holds at the same address, or be zero when
        the emulator has no entry for it. The register contexts are then
        compared as in _test_asm_instruction. The buffer is released even
        when an assertion fails.
        """
        print(asm)

        mem_dir = pyasmjit.arm_alloc(4096)

        try:
            # Build real lists: on Python 3 map() returns a one-shot
            # iterator that __set_address would exhaust.
            arm_instrs = [self.arm_parser.parse(line) for line in asm]

            self.__set_address(0xdeadbeef, arm_instrs)

            reil_instrs = [self.arm_translator.translate(instr)
                           for instr in arm_instrs]

            ctx_init = self.__init_context()

            ctx_init[reg_mem] = mem_dir

            _, arm_ctx_out, arm_mem_out = pyasmjit.arm_execute(
                "\n".join(asm), ctx_init)

            # TODO: Check how to clean emulator memory.
            self.reil_emulator._mem._memory = {}

            reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
                reil_instrs, 0xdeadbeef << 8, context=ctx_init)

            base_dir = mem_dir

            for idx, b in enumerate(
                    struct.unpack("B" * len(arm_mem_out), arm_mem_out)):
                # TODO: Don't access variable directly.
                if (base_dir + idx) in reil_mem_out._memory:
                    self.assertTrue(
                        b == reil_mem_out._memory[base_dir + idx])
                else:
                    # Memory in pyasmjit is initialized to 0
                    self.assertTrue(b == 0x0)

            cmp_result = self.__compare_contexts(ctx_init, arm_ctx_out,
                                                 reil_ctx_out)

            if not cmp_result:
                self.__save_failing_context(ctx_init)

            self.assertTrue(
                cmp_result,
                self.__print_contexts(ctx_init, arm_ctx_out, reil_ctx_out))
        finally:
            # There is only one memory pool, so there is no need (for now)
            # to specify the address.
            pyasmjit.arm_free()

    # TODO: R13 (SP), R14 (LR), R15 (PC) are outside testing scope for now

    def test_data_proc_inst(self):
        inst_samples = [
            # No flags
            ["mov r0, r1"],
            ["mov r3, r8"],
            ["mov r5, r8"],
            ["and r0, r1, r2"],
            ["and r0, r6, #0x33"],
            ["orr r3, r5, r8"],
            ["orr r3, r5, #0x79"],
            ["orr r3, r5, r8, lsl #0x19"],
            ["eor r3, r5, r8"],
            ["eor r8, r4, r5, lsl r6"],
            ["eor r8, r4, r5, lsl #0x11"],
            ["add r8, r9, r11"],
            ["sub r0, r3, r12"],
            ["cmp r3, r12"],
            ["cmn r3, r12"],
            ['mov r8, r5, lsl r6'],
            ['eor r8, r4, r5, lsl r6'],
            ['mul r3, r4, r8'],
            ["mov r8, #0", 'mul r3, r4, r8'],
            ['mul r3, r4, r4'],

            # Flags update
            ["movs r0, #0"],
            ["movs r0, #-10"],
            ["mov r0, #0x7FFFFFFF", "mov r1, r0", "adds r3, r0, r1"],
            ["mov r0, #0x7FFFFFFF", "mov r1, r0", "subs r3, r0, r1"],
            ["mov r0, #0x00FFFFFF", "add r1, r0, #10", "subs r3, r0, r1"],
            ["mov r0, #0xFFFFFFFF", "adds r3, r0, #10"],
            ["mov r0, #0x7FFFFFFF", "mov r1,  #5", "adds r3, r0, r1"],
            ["mov r0, #0x80000000", "mov r1,  #5", "subs r3, r0, r1"],

            # Flags evaluation
            ["moveq r0, r1"],
            ["movne r3, r8"],
            ["movcs r5, r8"],
            ["andcc r0, r1, r2"],
            ["andmi r0, r6, #0x33"],
            ["orrpl r3, r5, r8"],
            ["orrvs r3, r5, #0x79"],
            ["orrvc r3, r5, r8, lsl #0x19"],
            ["eorhi r3, r5, r8"],
            ["eorls r8, r4, r5, lsl r6"],
            ["eorge r8, r4, r5, lsl #0x11"],
            ["addlt r8, r9, r11"],
            ["subgt r0, r3, r12"],
            ["cmple r3, r12"],
            ["cmnal r3, r12"],
            ["addhs r8, r9, r11"],
            ["sublo r0, r3, r12"],
        ]

        for i in inst_samples:
            self._test_asm_instruction(i)

    # R12 is loaded with the memory address
    def test_mem_inst(self):
        inst_samples = [
            ["str r0, [r12]", "ldr  r1, [r12]"],
            ["stm r12!, {r0 - r4}", "ldmdb r12, {r5 - r9}"],
            ["stmia r12, {r8 - r9}", "ldmia r12, {r1 - r2}"],
            ["stmib r12!, {r11}", "ldmda r12!, {r3}"],
            [
                "add r12, r12, #0x100", "stmda r12, {r9 - r11}",
                "ldmda r12, {r1 - r3}"
            ],
            ["add r12, r12, #0x100", "stmdb r12!, {r3}", "ldmia r12!, {r9}"],
            ["add r12, r12, #0x100", "stmfd r12, {r2 - r4}"],
            ["stmfa r12!, {r1 - r4}"],
            ["add r12, r12, #0x100", "stmed r12, {r6 - r9}"],
            ["stmea r12!, {r5 - r7}"],
            [
                "mov r0, r13", "mov r13, r12", "push {r1 - r10}",
                "pop {r2 - r11}", "mov r13, r0", "mov r0, #0"
            ],  # The last inst. is needed because the emulator has no access to the real value of native r13 (SP) which is not passed in the context
            [
                "mov r0, r13", "mov r13, r12", "push {r2 - r11}",
                "pop {r1 - r10}", "mov r13, r0", "mov r0, #0"
            ],
        ]

        for i in inst_samples:
            self._test_asm_instruction_with_mem(i, 'r12')
class ArmTranslationTests(unittest.TestCase):

    def setUp(self):
        """Create the fixtures shared by the tests.

        Stores the translation mode and architecture mode, and builds the
        architecture information, parser, translator and REIL emulator
        from them. Also sets the file name used for failing contexts.
        """
        translation = FULL_TRANSLATION
        mode = ARCH_ARM_MODE_32

        self.trans_mode = translation
        self.arch_mode = mode

        arch_info = ArmArchitectureInformation(mode)
        self.arch_info = arch_info

        self.arm_parser = ArmParser(mode)
        self.arm_translator = ArmTranslator(mode, translation)
        self.reil_emulator = ReilEmulator(arch_info)

        self.context_filename = "failing_context.data"

    def __init_context(self):
        """Initialize register with random values.
        """
        context = self.__create_random_context()

        return context

    def __create_random_context(self):
        context = {}

        for reg in self.arch_info.registers_gp_base:
            if reg not in ['r13', 'r14', 'r15']:
                min_value, max_value = 0, 2**self.arch_info.operand_size - 1
                context[reg] = random.randint(min_value, max_value)

        # Only highest 4 bits (N, C, Z, V) are randomized, the rest are
        # left in the default (user mode) value.
        context['apsr'] = 0x00000010
        context['apsr'] |= random.randint(0x0, 0xF) << 28

        return context

    def __load_failing_context(self):
        f = open(self.context_filename, "rb")
        context = pickle.load(f)
        f.close()

        return context

    def __save_failing_context(self, context):
        f = open(self.context_filename, "wb")
        pickle.dump(context, f)
        f.close()

    def __compare_contexts(self, context_init, arm_context, reil_context):
        match = True
        mask = 2**32-1

        for reg in sorted(context_init.keys()):
            if (arm_context[reg] & mask) != (reil_context[reg] & mask):
                match = False
                break

        return match

    def __print_contexts(self, context_init, arm_context, reil_context):
        out = "Contexts don't match!\n\n"

        header_fmt = " {0:^8s} : {1:^16s} | {2:>16s} ?= {3:<16s}\n"
        header = header_fmt.format("Register", "Initial", "ARM", "REIL")
        ruler = "-" * len(header) + "\n"

        out += header
        out += ruler

        fmt = " {0:>8s} : {1:08x} | {2:08x} {eq} {3:08x} {marker}\n"

        mask = 2**64-1

        for reg in sorted(context_init.keys()):
            if (arm_context[reg] & mask) != (reil_context[reg] & mask):
                eq, marker = "!=", "<"
            else:
                eq, marker = "==", ""

            out += fmt.format(
                reg,
                context_init[reg] & mask,
                arm_context[reg] & mask,
                reil_context[reg] & mask,
                eq=eq,
                marker=marker
            )

        # Pretty print flags.
        reg = "apsr"
        fmt = "{0:s} ({1:>4s}) : {2:08x} ({3:s})"

        arm_value = arm_context[reg] & mask
        reil_value = reil_context[reg] & mask

        if arm_value != reil_value:
            arm_flags_str = self.__print_flags(arm_context[reg])
            reil_flags_str = self.__print_flags(reil_context[reg])

            out += "\n"
            out += fmt.format(reg, "ARM", arm_value, arm_flags_str) + "\n"
            out += fmt.format(reg, "reil", reil_value, reil_flags_str)

        return out

    def __print_flags(self, flags_reg):
        # flags
        flags = {
             31 : "nf",  # bit 31
             30 : "zf",  # bit 30
             29 : "cf",  # bit 29
             28 : "vf",  # bit 28
        }

        out = ""

        for bit, flag in flags.items():
            flag_str = flag.upper() if flags_reg & 2**bit else flag.lower()
            out += flag_str + " "

        return out[:-1]

    def __set_address(self, address, arm_instrs):
        addr = address

        for arm_instr in arm_instrs:
            arm_instr.address = addr
            arm_instr.size = 4
            addr += 4

    def __translate(self, asm_instrs):
        """Translate *asm_instrs* into a ReilContainer.

        Each assembly instruction becomes one ReilSequence holding the
        REIL instructions produced by ``self.arm_translator.translate``.
        Every sequence is linked to the following one through
        ``next_sequence_address``, and all sequences are added to the
        returned container in order.
        """
        instr_container = ReilContainer()

        instr_seq_prev = None

        for asm_instr in asm_instrs:
            instr_seq = ReilSequence()

            for reil_instr in self.arm_translator.translate(asm_instr):
                instr_seq.append(reil_instr)

            # Link the previous sequence to the one just built.
            if instr_seq_prev:
                instr_seq_prev.next_sequence_address = instr_seq.address

            instr_container.add(instr_seq)

            instr_seq_prev = instr_seq

        # NOTE(review): the last sequence's next_sequence_address is left
        # as ReilSequence initialised it. The code used to contain a branch
        # setting it to (last address + size) << 8, but it was guarded by a
        # variable that was never assigned, so it never ran. Confirm
        # whether the emulator should see that address before adding it.

        return instr_container

    def __asm_to_reil(self, asm_list, address):
        arm_instrs = [self.arm_parser.parse(asm) for asm in asm_list]

        self.__set_address(address, arm_instrs)

        reil_instrs = self.__translate(arm_instrs)

        return reil_instrs

    def __run_code(self, asm_list, address, ctx_init):
        """Execute *asm_list* with pyasmjit and with the REIL emulator.

        Both executions receive *ctx_init* as the initial context; the
        emulator starts at ``address << 8``. Returns the pair
        ``(arm_ctx_out, reil_ctx_out)``.
        """
        reil_instrs = self.__asm_to_reil(asm_list, address)

        program = "\n".join(asm_list)
        _rv, arm_ctx_out, _mem = pyasmjit.arm_execute(program, ctx_init)

        emulated = self.reil_emulator.execute(
            reil_instrs,
            start=address << 8,
            registers=ctx_init
        )
        reil_ctx_out, _ = emulated

        return arm_ctx_out, reil_ctx_out

    def __test_asm_instruction(self, asm_list):
        ctx_init = self.__init_context()

        arm_ctx_out, reil_ctx_out = self.__run_code(asm_list, 0xdeadbeef, ctx_init)

        cmp_result = self.__compare_contexts(ctx_init, arm_ctx_out, reil_ctx_out)

        if not cmp_result:
            self.__save_failing_context(ctx_init)

        self.assertTrue(cmp_result, self.__print_contexts(ctx_init, arm_ctx_out, reil_ctx_out))

    def __test_asm_instruction_with_mem(self, asm_list, address_register):
        """Like __test_asm_instruction, but also compares memory.

        A 4096-byte buffer is obtained from ``pyasmjit.arm_alloc`` and its
        address is stored in *address_register* of the initial context.
        After both executions, every byte returned by pyasmjit must equal
        the byte the emulator reads at the same address, or be zero when
        the emulator has no entry for it. The register contexts are then
        compared. The buffer is released even when an assertion fails.
        """
        # TODO: Merge with previous test function.

        mem_addr = pyasmjit.arm_alloc(4096)

        try:
            self.reil_emulator.reset()

            reil_instrs = self.__asm_to_reil(asm_list, 0xdeadbeef)

            ctx_init = self.__init_context()
            ctx_init[address_register] = mem_addr

            _, arm_ctx_out, arm_mem_out = pyasmjit.arm_execute("\n".join(asm_list), ctx_init)
            reil_ctx_out, reil_mem_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, registers=ctx_init)

            base_addr = mem_addr

            for idx, b in enumerate(struct.unpack("B" * len(arm_mem_out), arm_mem_out)):
                addr = base_addr + idx

                # TODO: Don't access variable directly.
                if addr in reil_mem_out._memory:
                    self.assertEqual(b, reil_mem_out.read(addr, 1))
                else:
                    # Memory in pyasmjit is initialized to 0.
                    self.assertEqual(b, 0x0)

            cmp_result = self.__compare_contexts(ctx_init, arm_ctx_out, reil_ctx_out)

            if not cmp_result:
                self.__save_failing_context(ctx_init)

            self.assertTrue(cmp_result, self.__print_contexts(ctx_init, arm_ctx_out, reil_ctx_out))
        finally:
            # NOTE: There is only one memory pool, so there is no need
            # (for now) to specify the address.
            pyasmjit.arm_free()

    def __execute_asm(self, asm_list, address=0x8000):
        reil_instrs = self.__asm_to_reil(asm_list, address)

        ctx_init = self.__init_context()

        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, address << 8, registers=ctx_init)

        return reil_ctx_out

    def test_data_instructions(self):
        # TODO: R13 (SP), R14 (LR), R15 (PC) are outside testing scope
        # for now.

        instr_samples = [
            # No flags
            ["mov r0, r1"],
            ["mov r3, r8"],
            ["mov r5, r8"],
            ["and r0, r1, r2"],
            ["and r0, r6, #0x33"],
            ["orr r3, r5, r8"],
            ["orr r3, r5, #0x79"],
            ["orr r3, r5, r8, lsl #0x19"],
            ["eor r3, r5, r8"],
            ["eor r8, r4, r5, lsl r6"],
            ["eor r8, r4, r5, lsl #0x11"],
            ["add r8, r9, r11"],
            ["sub r0, r3, r12"],
            ["cmp r3, r12"],
            ["cmn r3, r12"],
            ["mov r8, r5, lsl r6"],
            ["eor r8, r4, r5, lsl r6"],
            ["mul r3, r4, r8"],
            ["mov r8, #0",
             "mul r3, r4, r8"],
            ["mul r3, r4, r4"],

            # Flags update
            ["movs r0, #0"],
            ["movs r0, #-10"],
            ["mov r0, #0x7FFFFFFF",
             "mov r1, r0",
             "adds r3, r0, r1"],
            ["mov r0, #0x7FFFFFFF",
             "mov r1, r0",
             "subs r3, r0, r1"],
            ["mov r0, #0x00FFFFFF",
             "add r1, r0, #10",
             "subs r3, r0, r1"],
            ["mov r0, #0xFFFFFFFF",
             "adds r3, r0, #10"],
            ["mov r0, #0x7FFFFFFF",
             "mov r1,  #5",
             "adds r3, r0, r1"],
            ["mov r0, #0x80000000",
             "mov r1,  #5",
             "subs r3, r0, r1"],

            # Flags evaluation
            ["moveq r0, r1"],
            ["movne r3, r8"],
            ["movcs r5, r8"],
            ["andcc r0, r1, r2"],
            ["andmi r0, r6, #0x33"],
            ["orrpl r3, r5, r8"],
            ["orrvs r3, r5, #0x79"],
            ["orrvc r3, r5, r8, lsl #0x19"],
            ["eorhi r3, r5, r8"],
            ["eorls r8, r4, r5, lsl r6"],
            ["eorge r8, r4, r5, lsl #0x11"],
            ["addlt r8, r9, r11"],
            ["subgt r0, r3, r12"],
            ["cmple r3, r12"],
            ["cmnal r3, r12"],
            ["addhs r8, r9, r11"],
            ["sublo r0, r3, r12"],
        ]

        for instr in instr_samples:
            self.__test_asm_instruction(instr)

    def test_mememory_instructions(self):
        # R12 is loaded with the memory address

        instr_samples = [
            ["str r0, [r12]",
             "ldr  r1, [r12]"],
            ["stm r12!, {r0 - r4}",
             "ldmdb r12, {r5 - r9}"],
            ["stmia r12, {r8 - r9}",
             "ldmia r12, {r1 - r2}"],
            ["stmib r12!, {r11}",
             "ldmda r12!, {r3}"],
            ["add r12, r12, #0x100",
             "stmda r12, {r9 - r11}",
             "ldmda r12, {r1 - r3}"],
            ["add r12, r12, #0x100",
             "stmdb r12!, {r3}",
             "ldmia r12!, {r9}"],
            ["add r12, r12, #0x100",
             "stmfd r12, {r2 - r4}"],
            ["stmfa r12!, {r1 - r4}"],
            ["add r12, r12, #0x100",
             "stmed r12, {r6 - r9}"],
            ["stmea r12!, {r5 - r7}"],
            # NOTE: The last instr. is needed because the emulator has
            # no access to the real value of native r13 (SP) which is
            # not passed in the context.
            ["mov r0, r13",
             "mov r13, r12",
             "push {r1 - r10}",
             "pop {r2 - r11}",
             "mov r13, r0",
             "mov r0, #0"],
            ["mov r0, r13",
             "mov r13, r12",
             "push {r2 - r11}",
             "pop {r1 - r10}",
             "mov r13, r0",
             "mov r0, #0"],
        ]

        for instr in instr_samples:
            self.__test_asm_instruction_with_mem(instr, 'r12')

    def test_branch_instructions(self):
        untouched_value = 0x45454545
        touched_value = 0x31313131

        # R11 is used as a dirty register to check if the branch was
        # taken or not.
        instr_samples = [
            ["mov r11, #0x{:x}".format(untouched_value),
             "b #0x800c",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
            ["mov r11, #0x{:x}".format(untouched_value),
             "bx #0x800c",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
            ["mov r11, #0x{:x}".format(untouched_value),
             "bl #0x800c",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
            ["mov r11, #0x{:x}".format(untouched_value),
             "blx #0x800c",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
            ["movs r11, #0x{:x}".format(untouched_value),
             "bne #0x800c",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
            ["mov r11, #0x{:x}".format(untouched_value),
             "mov r1, #0x8010",
             "bx r1",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
            ["mov r11, #0x{:x}".format(untouched_value),
             "mov r1, #0x8010",
             "blx r1",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
        ]

        for instr in instr_samples:
            reil_ctx_out = self.__execute_asm(instr, 0x8000)

            self.assertTrue(reil_ctx_out['r11'] == untouched_value)
class ReilEmulatorTests(unittest.TestCase):

    def setUp(self):
        """Create the fixtures shared by the tests: the architecture
        information (built with ARCH_X86_MODE_32), a REIL emulator built
        from it, the assembly and REIL parsers, and the translator.
        """
        arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)
        self._arch_info = arch_info

        self._emulator = ReilEmulator(arch_info)

        # Parsers.
        self._asm_parser = X86Parser()
        self._reil_parser = ReilParser()

        self._translator = X86Translator()

    def test_add(self):
        asm_instrs  = self._asm_parser.parse("add eax, ebx")

        self.__set_address(0xdeadbeef, [asm_instrs])

        reil_instrs = self._translator.translate(asm_instrs)

        regs_initial = {
            "eax" : 0x1,
            "ebx" : 0x2,
        }

        regs_final, _ = self._emulator.execute_lite(
            reil_instrs,
            context=regs_initial
        )

        self.assertEqual(regs_final["eax"], 0x3)
        self.assertEqual(regs_final["ebx"], 0x2)

    def test_loop(self):
        # 0x08048060 : b8 00 00 00 00   mov eax,0x0
        # 0x08048065 : bb 0a 00 00 00   mov ebx,0xa
        # 0x0804806a : 83 c0 01         add eax,0x1
        # 0x0804806d : 83 eb 01         sub ebx,0x1
        # 0x08048070 : 83 fb 00         cmp ebx,0x0
        # 0x08048073 : 75 f5            jne 0x0804806a

        asm_instrs_str  = [(0x08048060, "mov eax,0x0", 5)]
        asm_instrs_str += [(0x08048065, "mov ebx,0xa", 5)]
        asm_instrs_str += [(0x0804806a, "add eax,0x1", 3)]
        asm_instrs_str += [(0x0804806d, "sub ebx,0x1", 3)]
        asm_instrs_str += [(0x08048070, "cmp ebx,0x0", 3)]
        asm_instrs_str += [(0x08048073, "jne 0x0804806a", 2)]

        asm_instrs = []

        for addr, asm, size in asm_instrs_str:
            asm_instr = self._asm_parser.parse(asm)
            asm_instr.address = addr
            asm_instr.size = size

            asm_instrs.append(asm_instr)

        reil_instrs = self.__translate(asm_instrs)

        regs_final, _ = self._emulator.execute(
            reil_instrs,
            start=0x08048060 << 8
        )

        self.assertEqual(regs_final["eax"], 0xa)
        self.assertEqual(regs_final["ebx"], 0x0)

    def test_mov(self):
        asm_instrs  = [self._asm_parser.parse("mov eax, 0xdeadbeef")]
        asm_instrs += [self._asm_parser.parse("mov al, 0x12")]
        asm_instrs += [self._asm_parser.parse("mov ah, 0x34")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs  = self._translator.translate(asm_instrs[0])
        reil_instrs += self._translator.translate(asm_instrs[1])
        reil_instrs += self._translator.translate(asm_instrs[2])

        regs_initial = {
            "eax" : 0xffffffff,
        }

        regs_final, _ = self._emulator.execute_lite(reil_instrs, context=regs_initial)

        self.assertEqual(regs_final["eax"], 0xdead3412)

    def test_pre_hanlder(self):
        def pre_hanlder(emulator, instruction, parameter):
            paramter.append(True)

        asm = ["mov eax, ebx"]

        x86_instrs = map(self._asm_parser.parse, asm)
        self.__set_address(0xdeadbeef, x86_instrs)
        reil_instrs = map(self._translator.translate, x86_instrs)

        paramter = []

        self._emulator.set_instruction_pre_handler(pre_hanlder, paramter)

        reil_ctx_out, reil_mem_out = self._emulator.execute_lite(
            reil_instrs[0]
        )

        self.assertTrue(len(paramter) > 0)

    def test_post_hanlder(self):
        def post_hanlder(emulator, instruction, parameter):
            paramter.append(True)

        asm = ["mov eax, ebx"]

        x86_instrs = map(self._asm_parser.parse, asm)
        self.__set_address(0xdeadbeef, x86_instrs)
        reil_instrs = map(self._translator.translate, x86_instrs)

        paramter = []

        self._emulator.set_instruction_post_handler(post_hanlder, paramter)

        reil_ctx_out, reil_mem_out = self._emulator.execute_lite(
            reil_instrs[0]
        )

        self.assertTrue(len(paramter) > 0)

    def test_zero_division_error_1(self):
        """Emulating "div ebx" with ebx=0 must raise
        ReilCpuZeroDivisionError.
        """
        div_instr = self._asm_parser.parse("div ebx")

        self.__set_address(0xdeadbeef, [div_instr])

        reil_instrs = self._translator.translate(div_instr)

        regs_initial = {"eax": 0x2, "edx": 0x2, "ebx": 0x0}

        with self.assertRaises(ReilCpuZeroDivisionError):
            self._emulator.execute_lite(reil_instrs, context=regs_initial)

    def test_zero_division_error_2(self):
        """Emulating the REIL "mod" instruction with ebx=0 as its second
        operand must raise ReilCpuZeroDivisionError.
        """
        source = ["mod [DWORD eax, DWORD ebx, DWORD t0]"]
        reil_instrs = self._reil_parser.parse(source)

        regs_initial = {"eax": 0x2, "ebx": 0x0}

        with self.assertRaises(ReilCpuZeroDivisionError):
            self._emulator.execute_lite(reil_instrs, context=regs_initial)

    def test_invalid_address_error_1(self):
        """Emulating "jmp eax" with eax=0xffffffff must raise
        ReilCpuInvalidAddressError.
        """
        jump = self._asm_parser.parse("jmp eax")
        asm_instrs = [jump]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self.__translate(asm_instrs)

        regs_initial = {"eax": 0xffffffff}

        with self.assertRaises(ReilCpuInvalidAddressError):
            self._emulator.execute(
                reil_instrs,
                start=0xdeadbeef << 8,
                registers=regs_initial
            )

    def test_invalid_address_error_2(self):
        """Starting the emulator at 0xdeadbef0 << 8, while the only
        instruction was placed at 0xdeadbeef, must raise
        ReilCpuInvalidAddressError.
        """
        move = self._asm_parser.parse("mov eax, 0xdeadbeef")
        asm_instrs = [move]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self.__translate(asm_instrs)

        regs_initial = {"eax": 0xffffffff}

        with self.assertRaises(ReilCpuInvalidAddressError):
            self._emulator.execute(
                reil_instrs,
                start=0xdeadbef0 << 8,
                registers=regs_initial
            )

    # Auxiliary methods
    # ======================================================================== #
    def __set_address(self, address, asm_instrs):
        addr = address

        for asm_instr in asm_instrs:
            asm_instr.address = addr
            addr += 1

    def __translate(self, asm_instrs):
        """Translate *asm_instrs* into a ReilContainer.

        Each assembly instruction becomes one ReilSequence holding the
        REIL instructions produced by ``self._translator.translate``.
        Every sequence is linked to the following one through
        ``next_sequence_address``, and all sequences are added to the
        returned container in order.
        """
        instr_container = ReilContainer()

        instr_seq_prev = None

        for asm_instr in asm_instrs:
            instr_seq = ReilSequence()

            for reil_instr in self._translator.translate(asm_instr):
                instr_seq.append(reil_instr)

            # Link the previous sequence to the one just built.
            if instr_seq_prev:
                instr_seq_prev.next_sequence_address = instr_seq.address

            instr_container.add(instr_seq)

            instr_seq_prev = instr_seq

        # NOTE(review): the last sequence's next_sequence_address is left
        # as ReilSequence initialised it. The code used to contain a branch
        # setting it to (last address + size) << 8, but it was guarded by a
        # variable that was never assigned, so it never ran. Confirm
        # whether the emulator should see that address before adding it.

        # instr_container.dump()

        return instr_container
class ReilEmulatorTests(unittest.TestCase):
    """Run short translated x86 (32-bit mode) snippets on the ReilEmulator."""

    def setUp(self):
        """Create the architecture info, emulator, parser and translator."""
        arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        emulator = ReilEmulator(arch_info.address_size)
        emulator.set_arch_registers(arch_info.registers_gp_all)
        emulator.set_arch_registers_size(arch_info.registers_size)
        emulator.set_reg_access_mapper(arch_info.alias_mapper)

        self._arch_info = arch_info
        self._emulator = emulator
        self._asm_parser = X86Parser()
        self._translator = X86Translator()

    def test_add(self):
        """``add eax, ebx`` with eax=1 and ebx=2 must give eax=3, ebx=2."""
        instr = self._asm_parser.parse("add eax, ebx")
        self.__set_address(0xdeadbeef, [instr])

        translated = self._translator.translate(instr)
        start_regs = {"eax": 0x1, "ebx": 0x2}

        regs_final, _ = self._emulator.execute_lite(translated,
                                                    context=start_regs)

        self.assertEqual(regs_final["eax"], 0x3)
        self.assertEqual(regs_final["ebx"], 0x2)

    def test_loop(self):
        """Run a counting loop and check eax == 0xa and ebx == 0 at the end."""
        # 0x08048060 : b8 00 00 00 00   mov eax,0x0
        # 0x08048065 : bb 0a 00 00 00   mov ebx,0xa
        # 0x0804806a : 83 c0 01         add eax,0x1
        # 0x0804806d : 83 eb 01         sub ebx,0x1
        # 0x08048070 : 83 fb 00         cmp ebx,0x0
        # 0x08048073 : 75 f5            jne 0x0804806a

        # One (address, assembly text, size) entry per instruction above.
        listing = [
            (0x08048060, "mov eax,0x0", 5),
            (0x08048065, "mov ebx,0xa", 5),
            (0x0804806a, "add eax,0x1", 3),
            (0x0804806d, "sub ebx,0x1", 3),
            (0x08048070, "cmp ebx,0x0", 3),
            (0x08048073, "jne 0x0804806a", 2),
        ]

        asm_instrs = []
        for address, text, size in listing:
            parsed = self._asm_parser.parse(text)
            parsed.address = address
            parsed.size = size
            asm_instrs.append(parsed)

        # One translated instruction list per assembly instruction.
        reil_instrs = []
        for parsed in asm_instrs:
            reil_instrs.append(self._translator.translate(parsed))

        regs_final, _ = self._emulator.execute(reil_instrs,
                                               0x08048060 << 8,
                                               context=[])

        self.assertEqual(regs_final["eax"], 0xa)
        self.assertEqual(regs_final["ebx"], 0x0)

    def test_mov(self):
        """Writes to al and ah must only replace the low two bytes of eax."""
        sources = ("mov eax, 0xdeadbeef", "mov al, 0x12", "mov ah, 0x34")
        asm_instrs = [self._asm_parser.parse(text) for text in sources]

        self.__set_address(0xdeadbeef, asm_instrs)

        # Extend the first translation in place with the remaining ones.
        reil_instrs = self._translator.translate(asm_instrs[0])
        for parsed in asm_instrs[1:]:
            reil_instrs += self._translator.translate(parsed)

        start_regs = {"eax": 0xffffffff}

        regs_final, _ = self._emulator.execute_lite(reil_instrs,
                                                    context=start_regs)

        self.assertEqual(regs_final["eax"], 0xdead3412)

    def __set_address(self, address, asm_instrs):
        """Give each instruction a consecutive address starting at `address`."""
        for offset, instr in enumerate(asm_instrs):
            instr.address = address + offset
Exemple #23
0
class ReilEmulatorTaintTests(unittest.TestCase):
    """Check the taint state the ReilEmulator reports after running x86 code."""

    def setUp(self):
        """Create a 32-bit x86 emulator, parser and translator."""
        arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        self._arch_info = arch_info
        self._emulator = ReilEmulator(arch_info)
        self._asm_parser = X86Parser()
        self._translator = X86Translator()

    def test_arithmetic(self):
        """With ebx tainted, eax must be tainted after ``add eax, ebx``."""
        emulator = self._emulator
        reil_instrs = self.__lift("add eax, ebx")
        start_regs = {"eax": 0x1, "ebx": 0x2}

        emulator.set_register_taint("ebx", True)
        _regs_final, _ = emulator.execute_lite(reil_instrs,
                                               context=start_regs)

        self.assertEqual(emulator.get_register_taint("eax"), True)

    def test_store_mem_1(self):
        """Storing a tainted ebx to [eax] must taint the 4 bytes at eax."""
        emulator = self._emulator
        reil_instrs = self.__lift("mov [eax], ebx")
        start_regs = {"eax": 0xcafecafe, "ebx": 0x2}

        emulator.set_register_taint("ebx", True)
        _regs_final, _ = emulator.execute_lite(reil_instrs,
                                               context=start_regs)

        stored_taint = emulator.get_memory_taint(start_regs["eax"], 4)
        self.assertEqual(stored_taint, True)

    def test_store_mem_2(self):
        """A tainted address register (eax) must not taint the stored bytes."""
        emulator = self._emulator
        reil_instrs = self.__lift("mov [eax], ebx")
        start_regs = {"eax": 0xcafecafe, "ebx": 0x2}

        emulator.set_register_taint("eax", True)
        _regs_final, _ = emulator.execute_lite(reil_instrs,
                                               context=start_regs)

        stored_taint = emulator.get_memory_taint(start_regs["eax"], 4)
        self.assertEqual(stored_taint, False)

    def test_load_mem_1(self):
        """Loading 4 tainted bytes from [ebx] must taint eax."""
        emulator = self._emulator
        reil_instrs = self.__lift("mov eax, [ebx]")
        start_regs = {"eax": 0x2, "ebx": 0xcafecafe}

        emulator.set_memory_taint(start_regs["ebx"], 4, True)
        _regs_final, _ = emulator.execute_lite(reil_instrs,
                                               context=start_regs)

        self.assertEqual(emulator.get_register_taint("eax"), True)

    def __lift(self, asm_text):
        """Parse `asm_text`, place it at 0xdeadbeef and return its translation."""
        parsed = self._asm_parser.parse(asm_text)
        self.__set_address(0xdeadbeef, [parsed])

        return self._translator.translate(parsed)

    def __set_address(self, address, asm_instrs):
        """Give each instruction a consecutive address starting at `address`."""
        for offset, instr in enumerate(asm_instrs):
            instr.address = address + offset
Exemple #24
0
class X86TranslationTests(unittest.TestCase):
    """Compare REIL emulation of x86 instructions against pyasmjit.

    Every test executes one instruction twice from the same initial register
    context: once through ``pyasmjit.execute`` and once through the x86
    translator plus the REIL emulator. The resulting contexts must agree on
    every register present in the initial context (AF and PF are copied from
    the pyasmjit result before comparing, see ``__fix_reil_flags``).
    """

    def setUp(self):
        """Create the parser, translator, SMT objects and REIL emulator."""
        self.trans_mode = FULL_TRANSLATION

        self.arch_mode = ARCH_X86_MODE_64

        self.arch_info = X86ArchitectureInformation(self.arch_mode)

        self.x86_parser = X86Parser(self.arch_mode)
        self.x86_translator = X86Translator(self.arch_mode, self.trans_mode)
        self.smt_solver = SmtSolver()
        self.smt_translator = SmtTranslator(self.smt_solver, self.arch_info.address_size)
        self.reil_emulator = ReilEmulator(self.arch_info.address_size)

        self.reil_emulator.set_arch_registers(self.arch_info.registers_gp)
        self.reil_emulator.set_arch_registers_size(self.arch_info.register_size)
        self.reil_emulator.set_reg_access_mapper(self.arch_info.register_access_mapper())

        self.smt_translator.set_reg_access_mapper(self.arch_info.register_access_mapper())
        self.smt_translator.set_arch_registers_size(self.arch_info.register_size)

    def test_lea(self):
        self.__check_translation(["lea eax, [ebx + 0x100]"])

    def test_cld(self):
        self.__check_translation(["cld"])

    def test_clc(self):
        self.__check_translation(["clc"])

    def test_nop(self):
        self.__check_translation(["nop"])

    def test_test(self):
        self.__check_translation(["test eax, ebx"])

    def test_not(self):
        self.__check_translation(["not eax"])

    def test_xor(self):
        self.__check_translation(["xor eax, ebx"])

    def test_or(self):
        self.__check_translation(["or eax, ebx"])

    def test_and(self):
        self.__check_translation(["and eax, ebx"])

    def test_cmp(self):
        self.__check_translation(["cmp eax, ebx"])

    def test_neg(self):
        self.__check_translation(["neg eax"])

    def test_dec(self):
        self.__check_translation(["dec eax"])

    def test_inc(self):
        self.__check_translation(["inc eax"])

    def test_div(self):
        # Uses its own context instead of the default one.
        ctx_init = {
            'rax'    : 0x10,
            'rbx'    : 0x2,
            'rdx'    : 0x0,
            'rflags' : 0x202,
        }

        self.__check_translation(["div ebx"], ctx_init=ctx_init)

    def test_imul(self):
        self.__check_translation(["imul eax, ebx"])

    def test_mul(self):
        self.__check_translation(["mul ebx"])

    def test_sbb(self):
        self.__check_translation(["sbb eax, ebx"])

    def test_sub(self):
        self.__check_translation(["sub eax, ebx"])

    def test_adc(self):
        self.__check_translation(["adc eax, ebx"])

    def test_add(self):
        self.__check_translation(["add eax, ebx"])

    def test_xchg(self):
        self.__check_translation(["xchg eax, ebx"])

    def test_movzx(self):
        self.__check_translation(["movzx eax, bx"])

    def test_mov(self):
        self.__check_translation(["mov eax, ebx"])

    def test_shr(self):
        self.__check_translation(["shr eax, 3"])

    def test_shl(self):
        self.__check_translation(["shl eax, 3"])

    def test_sal(self):
        self.__check_translation(["sal eax, 3"])

    def test_sar(self):
        self.__check_translation(["sar eax, 3"], address=0xdeadbeef)

    def test_stc(self):
        self.__check_translation(["stc"], address=0xdeadbeef)

    def __check_translation(self, asm, ctx_init=None, address=None):
        """Run `asm` through pyasmjit and REIL and assert both contexts match.

        :param asm: list of assembly source lines.
        :param ctx_init: initial register context; when None the context
            returned by ``__init_context`` is used.
        :param address: when not None, assigned to the ``address`` attribute
            of the first parsed instruction before translating it.
        """
        # Real lists (not lazy map objects) so that indexing works on both
        # Python 2 and Python 3.
        x86_instrs = [self.x86_parser.parse(line) for line in asm]

        if address is not None:
            x86_instrs[0].address = address

        reil_instrs = [self.x86_translator.translate(instr)
                       for instr in x86_instrs]

        if ctx_init is None:
            ctx_init = self.__init_context()

        x86_rv, x86_ctx_out = pyasmjit.execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        reil_ctx_out = self.__fix_reil_flags(reil_ctx_out, x86_ctx_out)

        self.assertTrue(self.__compare_contexts(
            ctx_init,
            x86_ctx_out,
            reil_ctx_out
        ))

    def __init_context(self):
        """Return the default initial register context used by the tests."""
        return {
            'rax'    : 0xa,
            'rbx'    : 0x2,
            'rcx'    : 0xb,
            'rdx'    : 0xc,
            'rdi'    : 0xd,
            'rsi'    : 0xe,
            'rflags' : 0x202,
        }

    def __compare_contexts(self, context_init, x86_context, reil_context):
        """Return True when both contexts agree on every initial register.

        Values are compared after masking to 64 bits. On the first mismatch
        the offending register is printed for both contexts (with decoded
        flags for 'rflags'/'eflags'), followed by the full comparison table,
        and False is returned.
        """
        mask = 2**64-1

        for reg in sorted(context_init.keys()):
            x86_value = x86_context[reg] & mask
            reil_value = reil_context[reg] & mask

            if x86_value == reil_value:
                continue

            if reg in ['rflags', 'eflags']:
                x86_flags_str = self.__print_flags(x86_context[reg])
                reil_flags_str = self.__print_flags(reil_context[reg])

                print("%s (x86)  : %s (%s)" % (reg, hex(x86_value), x86_flags_str))
                print("%s (reil) : %s (%s)" % (reg, hex(reil_value), reil_flags_str))
            else:
                print("%s (x86)  : %s " % (reg, hex(x86_value)))
                print("%s (reil) : %s " % (reg, hex(reil_value)))

            self.__print_contexts(context_init, x86_context, reil_context)

            return False

        return True

    def __print_contexts(self, context_init, x86_context, reil_context):
        """Print a table comparing both contexts register by register.

        Only registers present in `context_init` are listed; rows whose
        64-bit-masked values differ are flagged with "!=" and a "<" marker.
        """
        header_fmt = "{0:^8s} : {1:>16s} ?= {2:<16s}"
        header = header_fmt.format("Register", "x86", "REIL")
        ruler  = "-" * len(header)

        print(header)
        print(ruler)

        fmt = "{0:>8s} : {1:016x} {eq} {2:016x} ({1:>5d} {eq} {2:<5d}) {marker}"
        mask = 2**64-1

        for reg in sorted(context_init.keys()):
            x86_value = x86_context[reg] & mask
            reil_value = reil_context[reg] & mask

            if x86_value != reil_value:
                eq = "!="
                marker = "<"
            else:
                eq = "=="
                marker = ""

            # Function-call form: valid on Python 2 (parenthesised
            # expression) and Python 3 alike.
            print(fmt.format(
                reg,
                x86_value,
                reil_value,
                eq=eq,
                marker=marker
            ))

    def __print_flags(self, flags_reg):
        """Return the flag names, upper-case when set and lower-case when clear.

        Names are separated by single spaces and listed in ascending bit
        order so the output does not depend on dict iteration order.
        """
        # (bit position, flag name)
        flags = (
            (0, "cf"),
            (2, "pf"),
            (4, "af"),
            (6, "zf"),
            (7, "sf"),
            (10, "df"),
            (11, "of"),
        )

        return " ".join(
            flag.upper() if flags_reg & 2**bit else flag.lower()
            for bit, flag in flags
        )

    def __fix_reil_flags(self, reil_context, x86_context):
        """Return a copy of `reil_context` with AF and PF OR-ed in from x86.

        The flags register key is 'eflags' when present in the REIL context,
        otherwise 'rflags'. The input dictionaries are not modified.
        """
        reil_context_out = dict(reil_context)

        flags_reg = 'eflags' if 'eflags' in reil_context_out else 'rflags'

        # Remove this when AF and PF are implemented.
        reil_context_out[flags_reg] |= (x86_context[flags_reg] & 2**4) # AF
        reil_context_out[flags_reg] |= (x86_context[flags_reg] & 2**2) # PF

        return reil_context_out
class ArmTranslationTests(unittest.TestCase):

    def setUp(self):
        self.trans_mode = FULL_TRANSLATION

        self.arch_mode = ARCH_ARM_MODE_32

        self.arch_info = ArmArchitectureInformation(self.arch_mode)

        self.arm_parser = ArmParser(self.arch_mode)
        self.arm_translator = ArmTranslator(self.arch_mode, self.trans_mode)
        self.reil_emulator = ReilEmulator(self.arch_info.address_size)

        self.reil_emulator.set_arch_registers(self.arch_info.registers_gp_all)
        self.reil_emulator.set_arch_registers_size(self.arch_info.registers_size)
        self.reil_emulator.set_reg_access_mapper(self.arch_info.alias_mapper)

        self.context_filename = "failing_context.data"

    def __init_context(self):
        """Initialize register with random values.
        """
        context = self.__create_random_context()

        return context

    def __create_random_context(self):
        context = {}

        for reg in self.arch_info.registers_gp_base:
            if reg not in ['r13', 'r14', 'r15']:
                min_value, max_value = 0, 2**self.arch_info.operand_size - 1
                context[reg] = random.randint(min_value, max_value)

        # Only highest 4 bits (N, C, Z, V) are randomized, the rest are left in the default (user mode) value
        context['apsr'] = 0x00000010
        context['apsr'] |= random.randint(0x0, 0xF) << 28

        return context

    def __load_failing_context(self):
        f = open(self.context_filename, "rb")
        context = pickle.load(f)
        f.close()

        return context

    def __save_failing_context(self, context):
        f = open(self.context_filename, "wb")
        pickle.dump(context, f)
        f.close()

    def __compare_contexts(self, context_init, arm_context, reil_context):
        match = True
        mask = 2**32-1

        for reg in sorted(context_init.keys()):
            if (arm_context[reg] & mask) != (reil_context[reg] & mask):
                match = False
                break

        return match

    def __print_contexts(self, context_init, arm_context, reil_context):
        out = "Contexts don't match!\n\n"

        header_fmt = " {0:^8s} : {1:^16s} | {2:>16s} ?= {3:<16s}\n"
        header = header_fmt.format("Register", "Initial", "ARM", "REIL")
        ruler = "-" * len(header) + "\n"

        out += header
        out += ruler

        fmt = " {0:>8s} : {1:08x} | {2:08x} {eq} {3:08x} {marker}\n"

        mask = 2**64-1

        for reg in sorted(context_init.keys()):
            if (arm_context[reg] & mask) != (reil_context[reg] & mask):
                eq, marker = "!=", "<"
            else:
                eq, marker = "==", ""

            out += fmt.format(
                reg,
                context_init[reg] & mask,
                arm_context[reg] & mask,
                reil_context[reg] & mask,
                eq=eq,
                marker=marker
            )

        # Pretty print flags.
        reg = "apsr"
        fmt = "{0:s} ({1:>4s}) : {2:08x} ({3:s})"

        arm_value = arm_context[reg] & mask
        reil_value = reil_context[reg] & mask

        if arm_value != reil_value:
            arm_flags_str = self.__print_flags(arm_context[reg])
            reil_flags_str = self.__print_flags(reil_context[reg])

            out += "\n"
            out += fmt.format(reg, "ARM", arm_value, arm_flags_str) + "\n"
            out += fmt.format(reg, "reil", reil_value, reil_flags_str)

        return out

    def __print_flags(self, flags_reg):
        # flags
        flags = {
             31 : "nf",  # bit 31
             30 : "zf",  # bit 30
             29 : "cf",  # bit 29
             28 : "vf",  # bit 28
        }

        out = ""

        for bit, flag in flags.items():
            flag_str = flag.upper() if flags_reg & 2**bit else flag.lower()
            out +=  flag_str + " "

        return out[:-1]

    def __fix_reil_flag(self, reil_context, arm_context, flag):
        reil_context_out = dict(reil_context)

        flags_reg = 'eflags' if 'eflags' in reil_context_out else 'rflags'

        arch_size = self.arch_info.architecture_size

        _, bit = self.arch_info.alias_mapper[flag]

        # Clean flag.
        reil_context_out[flags_reg] &= ~(2**bit) & (2**32-1)

        # Copy flag.
        reil_context_out[flags_reg] |= (arm_context[flags_reg] & 2**bit)

        return reil_context_out

    def __fix_reil_flags(self, reil_context, arm_context):
        # Remove this when AF and PF are implemented.
        reil_context_out = self.__fix_reil_flag(reil_context, arm_context, "af")
        reil_context_out = self.__fix_reil_flag(reil_context, arm_context, "pf")

        return reil_context_out

    def __set_address(self, address, arm_instrs):
        addr = address

        for arm_instr in arm_instrs:
            arm_instr.address = addr
            arm_instr.size = 4
            addr += 4

    def _test_asm_instruction(self, asm):
        print(asm)

        arm_instrs = map(self.arm_parser.parse, asm)

        self.__set_address(0xdeadbeef, arm_instrs)

        reil_instrs = map(self.arm_translator.translate, arm_instrs)

        ctx_init = self.__init_context()

        arm_rv, arm_ctx_out, _ = pyasmjit.arm_execute("\n".join(asm), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        cmp_result = self.__compare_contexts(ctx_init, arm_ctx_out, reil_ctx_out)

        if not cmp_result:
            self.__save_failing_context(ctx_init)

        self.assertTrue(cmp_result, self.__print_contexts(ctx_init, arm_ctx_out, reil_ctx_out))

    def _execute_asm(self, asm, ini_addr = 0x8000):
        print(asm)

        arm_instrs = map(self.arm_parser.parse, asm)

        self.__set_address(ini_addr, arm_instrs)

        reil_instrs = map(self.arm_translator.translate, arm_instrs)

        ctx_init = self.__init_context()

        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        return reil_ctx_out

    # R11 is used as a dirty register to check if the branch was taken or not
    def test_asm_branch_instruction(self):
        untouched_value = 0x45454545
        touched_value = 0x31313131

        inst_samples_touched = [
            ["mov r11, #0x{:x}".format(untouched_value),
               "b #0x800c",
               "mov r11, #0x{:x}".format(touched_value),
               "mov r0, r0",
            ],

            ["mov r11, #0x{:x}".format(untouched_value),
               "bx #0x800c",
               "mov r11, #0x{:x}".format(touched_value),
               "mov r0, r0",
            ],

            ["mov r11, #0x{:x}".format(untouched_value),
               "bl #0x800c",
               "mov r11, #0x{:x}".format(touched_value),
               "mov r0, r0",
            ],

            ["mov r11, #0x{:x}".format(untouched_value),
               "blx #0x800c",
               "mov r11, #0x{:x}".format(touched_value),
               "mov r0, r0",
            ],

            ["movs r11, #0x{:x}".format(untouched_value),
               "bne #0x800c",
               "mov r11, #0x{:x}".format(touched_value),
               "mov r0, r0",
            ],

            ["mov r11, #0x{:x}".format(untouched_value),
               "mov r1, #0x8010",
               "bx r1",
               "mov r11, #0x{:x}".format(touched_value),
               "mov r0, r0",
            ],

            ["mov r11, #0x{:x}".format(untouched_value),
               "mov r1, #0x8010",
               "blx r1",
               "mov r11, #0x{:x}".format(touched_value),
               "mov r0, r0",
            ],

        ]


        for asm in inst_samples_touched:
            reil_ctx_out = self._execute_asm(asm, 0x8000)
            self.assertTrue(reil_ctx_out['r11'] == untouched_value)
#             print reil_ctx_out

    # TODO: Merge with previous test function
    def _test_asm_instruction_with_mem(self, asm, reg_mem):
        print(asm)

        mem_dir = pyasmjit.arm_alloc(4096)

        arm_instrs = map(self.arm_parser.parse, asm)

        self.__set_address(0xdeadbeef, arm_instrs)

        reil_instrs = map(self.arm_translator.translate, arm_instrs)

        ctx_init = self.__init_context()

        ctx_init[reg_mem] = mem_dir

        arm_rv, arm_ctx_out, arm_mem_out = pyasmjit.arm_execute("\n".join(asm), ctx_init)

        self.reil_emulator._mem._memory = {} # TODO: Check how to clean emulator memory.

        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
            reil_instrs,
            0xdeadbeef << 8,
            context=ctx_init
        )

        base_dir = mem_dir

        for idx, b in enumerate(struct.unpack("B" * len(arm_mem_out), arm_mem_out)):
            if (base_dir + idx) in reil_mem_out._memory: # TODO: Don't access variable directly.
                self.assertTrue(b == reil_mem_out._memory[base_dir + idx])
            else:
                self.assertTrue(b == 0x0) # Memory in pyasmjit is initialized to 0


        cmp_result = self.__compare_contexts(ctx_init, arm_ctx_out, reil_ctx_out)

        if not cmp_result:
            self.__save_failing_context(ctx_init)

        self.assertTrue(cmp_result, self.__print_contexts(ctx_init, arm_ctx_out, reil_ctx_out))

        pyasmjit.arm_free() # There is only one memory pool, so there is no need (for now) to specify the address


    # TODO: R13 (SP), R14 (LR), R15 (PC) are outside testing scope for now

    def test_data_proc_inst(self):
        inst_samples = [
            # No flags
            ["mov r0, r1"],
            ["mov r3, r8"],
            ["mov r5, r8"],
            ["and r0, r1, r2"],
            ["and r0, r6, #0x33"],
            ["orr r3, r5, r8"],
            ["orr r3, r5, #0x79"],
            ["orr r3, r5, r8, lsl #0x19"],
            ["eor r3, r5, r8"],
            ["eor r8, r4, r5, lsl r6"],
            ["eor r8, r4, r5, lsl #0x11"],
            ["add r8, r9, r11"],
            ["sub r0, r3, r12"],
            ["cmp r3, r12"],
            ["cmn r3, r12"],
            ['mov r8, r5, lsl r6'],
            ['eor r8, r4, r5, lsl r6'],
            ['mul r3, r4, r8'],
            ["mov r8, #0", 'mul r3, r4, r8'],
            ['mul r3, r4, r4'],

            # Flags update
            ["movs r0, #0"],
            ["movs r0, #-10"],
            ["mov r0, #0x7FFFFFFF", "mov r1, r0", "adds r3, r0, r1"],
            ["mov r0, #0x7FFFFFFF", "mov r1, r0", "subs r3, r0, r1"],
            ["mov r0, #0x00FFFFFF", "add r1, r0, #10", "subs r3, r0, r1"],
            ["mov r0, #0xFFFFFFFF", "adds r3, r0, #10"],
            ["mov r0, #0x7FFFFFFF", "mov r1,  #5", "adds r3, r0, r1"],
            ["mov r0, #0x80000000", "mov r1,  #5", "subs r3, r0, r1"],


            # Flags evaluation
            ["moveq r0, r1"],
            ["movne r3, r8"],
            ["movcs r5, r8"],
            ["andcc r0, r1, r2"],
            ["andmi r0, r6, #0x33"],
            ["orrpl r3, r5, r8"],
            ["orrvs r3, r5, #0x79"],
            ["orrvc r3, r5, r8, lsl #0x19"],
            ["eorhi r3, r5, r8"],
            ["eorls r8, r4, r5, lsl r6"],
            ["eorge r8, r4, r5, lsl #0x11"],
            ["addlt r8, r9, r11"],
            ["subgt r0, r3, r12"],
            ["cmple r3, r12"],
            ["cmnal r3, r12"],
            ["addhs r8, r9, r11"],
            ["sublo r0, r3, r12"],
        ]


        for i in inst_samples:
            self._test_asm_instruction(i)

    # R12 is loaded with the memory address
    def test_mem_inst(self):
        inst_samples = [
            ["str r0, [r12]", "ldr  r1, [r12]"],
            ["stm r12!, {r0 - r4}",   "ldmdb r12, {r5 - r9}"],
            ["stmia r12, {r8 - r9}",  "ldmia r12, {r1 - r2}"],
            ["stmib r12!, {r11}",    "ldmda r12!, {r3}"],
            ["add r12, r12, #0x100", "stmda r12, {r9 - r11}",   "ldmda r12, {r1 - r3}"],
            ["add r12, r12, #0x100", "stmdb r12!, {r3}",   "ldmia r12!, {r9}"],
            ["add r12, r12, #0x100", "stmfd r12, {r2 - r4}"],
            ["stmfa r12!, {r1 - r4}"],
            ["add r12, r12, #0x100", "stmed r12, {r6 - r9}"],
            ["stmea r12!, {r5 - r7}"],
            ["mov r0, r13", "mov r13, r12", "push {r1 - r10}", "pop {r2 - r11}", "mov r13, r0", "mov r0, #0"], # The last inst. is needed because the emulator has no access to the real value of native r13 (SP) which is not passed in the context
            ["mov r0, r13", "mov r13, r12", "push {r2 - r11}", "pop {r1 - r10}", "mov r13, r0", "mov r0, #0"],
        ]

        for i in inst_samples:
            self._test_asm_instruction_with_mem(i, 'r12')
class ReilEmulatorTaintTests(unittest.TestCase):

    def setUp(self):
        self._arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        self._emulator = ReilEmulator(self._arch_info)

        self._asm_parser = X86Parser()
        self._translator = X86Translator()

    def test_arithmetic(self):
        asm_instrs  = self._asm_parser.parse("add eax, ebx")

        self.__set_address(0xdeadbeef, [asm_instrs])

        reil_instrs = self._translator.translate(asm_instrs)

        regs_initial = {
            "eax" : 0x1,
            "ebx" : 0x2,
        }

        self._emulator.set_register_taint("ebx", True)

        regs_final, _ = self._emulator.execute_lite(
            reil_instrs,
            context=regs_initial
        )

        self.assertEqual(self._emulator.get_register_taint("eax"), True)

    def test_store_mem_1(self):
        asm_instrs  = self._asm_parser.parse("mov [eax], ebx")

        self.__set_address(0xdeadbeef, [asm_instrs])

        reil_instrs = self._translator.translate(asm_instrs)

        regs_initial = {
            "eax" : 0xcafecafe,
            "ebx" : 0x2,
        }

        self._emulator.set_register_taint("ebx", True)

        regs_final, _ = self._emulator.execute_lite(
            reil_instrs,
            context=regs_initial
        )

        self.assertEqual(self._emulator.get_memory_taint(regs_initial['eax'], 4), True)

    def test_store_mem_2(self):
        asm_instrs  = self._asm_parser.parse("mov [eax], ebx")

        self.__set_address(0xdeadbeef, [asm_instrs])

        reil_instrs = self._translator.translate(asm_instrs)

        regs_initial = {
            "eax" : 0xcafecafe,
            "ebx" : 0x2,
        }

        self._emulator.set_register_taint("eax", True)

        regs_final, _ = self._emulator.execute_lite(
            reil_instrs,
            context=regs_initial
        )

        self.assertEqual(self._emulator.get_memory_taint(regs_initial['eax'], 4), False)

    def test_load_mem_1(self):
        asm_instrs  = self._asm_parser.parse("mov eax, [ebx]")

        self.__set_address(0xdeadbeef, [asm_instrs])

        reil_instrs = self._translator.translate(asm_instrs)

        regs_initial = {
            "eax" : 0x2,
            "ebx" : 0xcafecafe,
        }

        self._emulator.set_memory_taint(regs_initial["ebx"], 4, True)

        regs_final, _ = self._emulator.execute_lite(
            reil_instrs,
            context=regs_initial
        )

        self.assertEqual(self._emulator.get_register_taint("eax"), True)

    def __set_address(self, address, asm_instrs):
        addr = address

        for asm_instr in asm_instrs:
            asm_instr.address = addr
            addr += 1
class ArmTranslationTests(unittest.TestCase):

    def setUp(self):
        self.trans_mode = FULL_TRANSLATION
        self.arch_mode = ARCH_ARM_MODE_THUMB
        self.arch_info = ArmArchitectureInformation(self.arch_mode)
        self.arm_parser = ArmParser(self.arch_mode)
        self.arm_translator = ArmTranslator(self.arch_mode, self.trans_mode)
        self.reil_emulator = ReilEmulator(self.arch_info)

        self.context_filename = "failing_context.data"

    def __init_context(self):
        """Initialize register with random values.
        """
        context = self.__create_random_context()

        return context

    def __create_random_context(self):
        context = {}

        for reg in self.arch_info.registers_gp_base:
            if reg not in ['r13', 'r14', 'r15']:
                min_value, max_value = 0, 2**self.arch_info.operand_size - 1
                context[reg] = random.randint(min_value, max_value)

        # Only highest 4 bits (N, C, Z, V) are randomized, the rest are
        # left in the default (user mode) value.
        context['apsr'] = 0x00000010
        context['apsr'] |= random.randint(0x0, 0xF) << 28

        return context

    def __load_failing_context(self):
        f = open(self.context_filename, "rb")
        context = pickle.load(f)
        f.close()

        return context

    def __save_failing_context(self, context):
        f = open(self.context_filename, "wb")
        pickle.dump(context, f)
        f.close()

    def __compare_contexts(self, context_init, arm_context, reil_context):
        match = True
        mask = 2**32-1

        for reg in sorted(context_init.keys()):
            if (arm_context[reg] & mask) != (reil_context[reg] & mask):
                match = False
                break

        return match

    def __print_contexts(self, context_init, arm_context, reil_context):
        out = "Contexts don't match!\n\n"

        header_fmt = " {0:^8s} : {1:^16s} | {2:>16s} ?= {3:<16s}\n"
        header = header_fmt.format("Register", "Initial", "ARM", "REIL")
        ruler = "-" * len(header) + "\n"

        out += header
        out += ruler

        fmt = " {0:>8s} : {1:08x} | {2:08x} {eq} {3:08x} {marker}\n"

        mask = 2**64-1

        for reg in sorted(context_init.keys()):
            if (arm_context[reg] & mask) != (reil_context[reg] & mask):
                eq, marker = "!=", "<"
            else:
                eq, marker = "==", ""

            out += fmt.format(
                reg,
                context_init[reg] & mask,
                arm_context[reg] & mask,
                reil_context[reg] & mask,
                eq=eq,
                marker=marker
            )

        # Pretty print flags.
        reg = "apsr"
        fmt = "{0:s} ({1:>4s}) : {2:08x} ({3:s})"

        arm_value = arm_context[reg] & mask
        reil_value = reil_context[reg] & mask

        if arm_value != reil_value:
            arm_flags_str = self.__print_flags(arm_context[reg])
            reil_flags_str = self.__print_flags(reil_context[reg])

            out += "\n"
            out += fmt.format(reg, "ARM", arm_value, arm_flags_str) + "\n"
            out += fmt.format(reg, "reil", reil_value, reil_flags_str)

        return out

    def __print_flags(self, flags_reg):
        # flags
        flags = {
             31 : "nf",  # bit 31
             30 : "zf",  # bit 30
             29 : "cf",  # bit 29
             28 : "vf",  # bit 28
        }

        out = ""

        for bit, flag in flags.items():
            flag_str = flag.upper() if flags_reg & 2**bit else flag.lower()
            out += flag_str + " "

        return out[:-1]

    def __set_address(self, address, arm_instrs):
        addr = address

        for arm_instr in arm_instrs:
            arm_instr.address = addr
            arm_instr.size = 4
            addr += 4

    def __translate(self, asm_instrs):
        instr_container = ReilContainer()

        asm_instr_last = None
        instr_seq_prev = None

        for asm_instr in asm_instrs:
            instr_seq = ReilSequence()

            for reil_instr in self.arm_translator.translate(asm_instr):
                instr_seq.append(reil_instr)

            if instr_seq_prev:
                instr_seq_prev.next_sequence_address = instr_seq.address

            instr_container.add(instr_seq)

            instr_seq_prev = instr_seq

        if instr_seq_prev:
            if asm_instr_last:
                instr_seq_prev.next_sequence_address = (asm_instr_last.address + asm_instr_last.size) << 8

        return instr_container

    def __asm_to_reil_use_parser(self, asm_list, address):
        # Using the parser:

        arm_instrs = [self.arm_parser.parse(asm) for asm in asm_list]

        self.__set_address(address, arm_instrs)

        reil_instrs = self.__translate(arm_instrs)

        return reil_instrs

    def __asm_to_reil(self, asm_list, address):
        # Using gcc:

        asm = "\n".join(asm_list)

        bytes = self._arm_compile(asm)

        curr_addr = 0
        end_addr = len(bytes)

        dis = ArmDisassembler();

        arm_instr_list = []

        while curr_addr < end_addr:
            # disassemble instruction
            start, end = curr_addr, min(curr_addr + 16, end_addr)

            USE_ARM = 0
            arm_instr = dis.disassemble(bytes[start:end], 0x8000, USE_ARM)

            if not arm_instr:
                raise Exception("Error in capstone disassembly")

            arm_instr_list.append(arm_instr)

            # update instruction pointer
            curr_addr += arm_instr.size


        # TODO: Separate parser tests vs CS->BARF translator
#         arm_instrs = [self.arm_parser.parse(asm) for asm in asm_list]

        self.__set_address(address, arm_instr_list)

        reil_instrs = self.__translate(arm_instr_list)

#         # DEBUG:
#         for reil_instr in reil_instrs:
#             print("  {}".format(reil_instr))

        return reil_instrs

    def _arm_compile(self, assembly):
        # TODO: This is a copy of the pyasmjit

        # Initialize return values
        rc = 0
        ctx = {}

        # Create temporary files for compilation.
        f_asm = tempfile.NamedTemporaryFile(delete=False)
        f_obj = tempfile.NamedTemporaryFile(delete=False)
        f_bin = tempfile.NamedTemporaryFile(delete=False)

        # Write assembly to a file.
        f_asm.write(assembly + '\n')  # TODO: -mthumb is not working (so the "code" directive is added)
        f_asm.close()

        # Run nasm.
        cmd_fmt = "gcc -c -x assembler {asm} -o {obj} -mthumb -march=armv7-a; objcopy -O binary {obj} {bin};"
        cmd = cmd_fmt.format(asm=f_asm.name, obj=f_obj.name, bin=f_bin.name)
        return_code = subprocess.call(cmd, shell=True)

        # Check for assembler errors.
        if return_code == 0:
            # Read binary code.
            binary = ""
            byte = f_bin.read(1)
            while byte:
                binary += byte
                byte = f_bin.read(1)
            f_bin.close()

        else:
            raise Exception("gcc error")

        # Remove temporary files.
        os.remove(f_asm.name)
        os.remove(f_obj.name)
        os.remove(f_bin.name)

        return binary


    def __run_code(self, asm_list, address, ctx_init):
        reil_instrs = self.__asm_to_reil(asm_list, address)

        _, arm_ctx_out, _ = pyasmjit.arm_execute("\n".join(asm_list), ctx_init)
        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, start=address << 8, registers=ctx_init)

        return arm_ctx_out, reil_ctx_out

    def __test_asm_instruction(self, asm_list):
        ctx_init = self.__init_context()

        arm_ctx_out, reil_ctx_out = self.__run_code(asm_list, 0xdeadbeef, ctx_init)

        cmp_result = self.__compare_contexts(ctx_init, arm_ctx_out, reil_ctx_out)

        if not cmp_result:
            self.__save_failing_context(ctx_init)

        self.assertTrue(cmp_result, self.__print_contexts(ctx_init, arm_ctx_out, reil_ctx_out))

    def __test_asm_instruction_with_mem(self, asm_list, address_register):
        # TODO: Merge with previous test function.

        mem_addr = pyasmjit.arm_alloc(4096)

        self.reil_emulator.reset()

        reil_instrs = self.__asm_to_reil(asm_list, 0xdeadbeef)

        ctx_init = self.__init_context()
        ctx_init[address_register] = mem_addr

        _, arm_ctx_out, arm_mem_out = pyasmjit.arm_execute("\n".join(asm_list), ctx_init)
        reil_ctx_out, reil_mem_out = self.reil_emulator.execute(reil_instrs, 0xdeadbeef << 8, registers=ctx_init)

        base_addr = mem_addr

        for idx, b in enumerate(struct.unpack("B" * len(arm_mem_out), arm_mem_out)):
            addr = base_addr + idx

            # TODO: Don't access variable directly.
            if addr in reil_mem_out._memory:
                self.assertTrue(b == reil_mem_out.read(addr, 1))
            else:
                # Memory in pyasmjit is initialized to 0.
                self.assertTrue(b == 0x0)

        cmp_result = self.__compare_contexts(ctx_init, arm_ctx_out, reil_ctx_out)

        if not cmp_result:
            self.__save_failing_context(ctx_init)

        self.assertTrue(cmp_result, self.__print_contexts(ctx_init, arm_ctx_out, reil_ctx_out))

        # NOTE: There is only one memory pool, so there is no need
        # (for now) to specify the address.
        pyasmjit.arm_free()

    def __execute_asm(self, asm_list, address=0x8000):
        reil_instrs = self.__asm_to_reil_use_parser(asm_list, address)

        ctx_init = self.__init_context()

        reil_ctx_out, _ = self.reil_emulator.execute(reil_instrs, address << 8, registers=ctx_init)

        return reil_ctx_out

    def test_data_instructions(self):
        # TODO: R13 (SP), R14 (LR), R15 (PC) are outside testing scope
        # for now.

        instr_samples = [
            # No flags
            ["mov r0, r1"],
            ["mov r3, r8"],
            ["mov r5, r8"],
            ["and r0, r1, r2"],
            ["and r0, r6, #0x33"],
            ["orr r3, r5, r8"],
            ["orr r3, r5, #0x79"],
            ["orr r3, r5, r8, lsl #0x19"],
            ["eor r3, r5, r8"],
            ["eor r8, r4, r5, lsl r6"],
            ["eor r8, r4, r5, lsl #0x11"],
            ["add r8, r9, r11"],
            ["sub r0, r3, r12"],
            ["rsb r0, r3, r12"],
            ["cmp r3, r12"],
            ["cmn r3, r12"],
            ["mov r8, r5, lsl r6"],
            ["eor r8, r4, r5, lsl r6"],
            ["mul r3, r4, r8"],
            ["mov r8, #0",
             "mul r3, r4, r8"],
            ["mul r3, r4, r4"],
            # ["movw r5, #0x1235"], # Not supported by the Raspberry Pi (ARMv6)
            ["mvn r3, r8"],
            # ["lsl r7, r2"], # TODO: Not implemented yet.
            ["lsl r2, r4, #0x0"],
            ["lsl r2, r4, #0x1"],
            ["lsl r2, r4, #10"],
            ["lsl r2, r4, #31"],


            # Flags update
            ["movs r0, #0"],
            ["movs r0, #-10"],
            ["mov r0, #0x7FFFFFFF",
             "mov r1, r0",
             "adds r3, r0, r1"],
            ["mov r0, #0x7FFFFFFF",
             "mov r1, r0",
             "subs r3, r0, r1"],
            ["mov r0, #0x00FFFFFF",
             "add r1, r0, #10",
             "subs r3, r0, r1"],
            ["mov r0, #0x00FFFFFF",
             "add r1, r0, #10",
             "rsbs r3, r0, r1"],
            ["mov r0, #0xFFFFFFFF",
             "adds r3, r0, #10"],
            ["mov r0, #0x7FFFFFFF",
             "mov r1,  #5",
             "adds r3, r0, r1"],
            ["mov r0, #0x80000000",
             "mov r1,  #5",
             "subs r3, r0, r1"],
            ["mov r0, #0x80000000",
             "mov r1,  #5",
             "rsbs r3, r0, r1"],
            ["lsls r2, r4, #0x0"],
            ["lsls r2, r4, #0x1"],
            ["lsls r2, r4, #10"],
            ["lsls r2, r4, #31"],

            # TODO: CHECK ReilCpuInvalidAddressError !!!!

            # Flags evaluation
#             ["moveq r0, r1"],
#             ["movne r3, r8"],
#             ["movcs r5, r8"],
#             ["andcc r0, r1, r2"],
#             ["andmi r0, r6, #0x33"],
#             ["orrpl r3, r5, r8"],
#             ["orrvs r3, r5, #0x79"],
#             ["orrvc r3, r5, r8, lsl #0x19"],
#             ["eorhi r3, r5, r8"],
#             ["eorls r8, r4, r5, lsl r6"],
#             ["eorge r8, r4, r5, lsl #0x11"],
#             ["addlt r8, r9, r11"],
#             ["subgt r0, r3, r12"],
#             # ["rsbgt r0, r3, r12"], #TODO: Check this after the AddressError Fix
#             ["cmple r3, r12"],
#             ["cmnal r3, r12"],
#             ["addhs r8, r9, r11"],
#             ["sublo r0, r3, r12"],
#             # ["rsblo r0, r3, r12"], #TODO: Check this after the AddressError Fix
        ]

        for instr in instr_samples:
            self.__test_asm_instruction(instr)

    def test_memory_instructions(self):
        # R12 is loaded with the memory address

        instr_samples = [
            ["str r0, [r12]",
             "ldr  r1, [r12]"],
            ["strb r0, [r12]",
             "ldrb  r1, [r12]"],
            ["strh r0, [r12]",
             "ldrh  r1, [r12]"],
            ["strd r6, [r12]",
             "ldrd  r2, [r12]"],
            ["strd r6, r7, [r12]",
             "ldrd  r8, r9, [r12]"],
            ["stm r12!, {r0 - r4}",
             "ldmdb r12, {r5 - r9}"],
            ["stmia r12, {r8 - r9}",
             "ldmia r12, {r1 - r2}"],
            ["stmib r12!, {r11}",
             "ldmda r12!, {r3}"],
            ["add r12, r12, #0x100",
             "stmda r12, {r9 - r11}",
             "ldmda r12, {r1 - r3}"],
            ["add r12, r12, #0x100",
             "stmdb r12!, {r3}",
             "ldmia r12!, {r9}"],
            ["add r12, r12, #0x100",
             "stmfd r12, {r2 - r4}"],
            ["stmfa r12!, {r1 - r4}"],
            ["add r12, r12, #0x100",
             "stmed r12, {r6 - r9}"],
            ["stmea r12!, {r5 - r7}"],
            # NOTE: The last instr. is needed because the emulator has
            # no access to the real value of native r13 (SP) which is
            # not passed in the context.
            ["mov r0, r13",
             "mov r13, r12",
             "push {r1 - r10}",
             "pop {r2 - r11}",
             "mov r13, r0",
             "mov r0, #0"],

            # TODO: Investigate sporadic seg fault in RPi
#             ["mov r0, r13",
#              "add r13, r12",
#              "push {r2 - r11}",
#              "pop {r1 - r10}",
#              "mov r13, r0",
#              "mov r0, #0"],
        ]

        for instr in instr_samples:
            self.__test_asm_instruction_with_mem(instr, 'r12')

    def test_branch_instructions(self):
        untouched_value = 0x45454545
        touched_value = 0x31313131

        # R11 is used as a dirty register to check if the branch was
        # taken or not.
        instr_samples = [
            ["mov r11, #0x{:x}".format(untouched_value),
             "b #0x800c",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
            ["mov r11, #0x{:x}".format(untouched_value),
             "bx #0x800c",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
            ["mov r11, #0x{:x}".format(untouched_value),
             "bl #0x800c",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
            ["mov r11, #0x{:x}".format(untouched_value),
             "blx #0x800c",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
            ["movs r11, #0x{:x}".format(untouched_value),
             "bne #0x800c",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
            ["mov r11, #0x{:x}".format(untouched_value),
             "mov r1, #0x8010",
             "bx r1",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
            ["mov r11, #0x{:x}".format(untouched_value),
             "mov r1, #0x8010",
             "blx r1",
             "mov r11, #0x{:x}".format(touched_value),
             "mov r0, r0",
            ],
        ]

        for instr in instr_samples:
            reil_ctx_out = self.__execute_asm(instr, 0x8000)

            self.assertTrue(reil_ctx_out['r11'] == untouched_value)
class ArmGadgetClassifierTests(unittest.TestCase):

    def setUp(self):

        self._arch_info = ArmArchitectureInformation(ARCH_ARM_MODE_32)
        self._smt_solver = SmtSolver()
        self._smt_translator = SmtTranslator(self._smt_solver, self._arch_info.address_size)
        self._ir_emulator = ReilEmulator(self._arch_info.address_size)

        self._ir_emulator.set_arch_registers(self._arch_info.registers_gp_all)
        self._ir_emulator.set_arch_registers_size(self._arch_info.registers_size)
        self._ir_emulator.set_reg_access_mapper(self._arch_info.alias_mapper)

        self._smt_translator.set_reg_access_mapper(self._arch_info.alias_mapper)
        self._smt_translator.set_arch_registers_size(self._arch_info.registers_size)

        self._code_analyzer = CodeAnalyzer(self._smt_solver, self._smt_translator)

        self._g_classifier = GadgetClassifier(self._ir_emulator, self._arch_info)
        self._g_verifier = GadgetVerifier(self._code_analyzer, self._arch_info)

    def _find_and_classify_gadgets(self, binary):
        g_finder = GadgetFinder(ArmDisassembler(), binary, ArmTranslator(translation_mode=LITE_TRANSLATION), ARCH_ARM, ARCH_ARM_MODE_32)

        g_candidates = g_finder.find(0x00000000, len(binary), instrs_depth=4)
        g_classified = self._g_classifier.classify(g_candidates[0])

#         Debug:
#         self._print_candidates(g_candidates)
#         self._print_classified(g_classified)

        return g_candidates, g_classified

    def test_move_register_1(self):
        # testing : dst_reg <- src_reg
        binary  = "\x04\x00\xa0\xe1"                     # 0x00 : (4)  mov    r0, r4
        binary += "\x31\xff\x2f\xe1"                     # 0x04 : (4)  blx    r1

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 2)

        self.assertEquals(g_classified[0].type, GadgetType.MoveRegister)
        self.assertEquals(g_classified[0].sources, [ReilRegisterOperand("r4", 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r0", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 1)
        self.assertTrue(ReilRegisterOperand("r14", 32) in g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_move_register_2(self):
        # testing : dst_reg <- src_reg
        binary  = "\x00\x00\x84\xe2"                     # 0x00 : (4)  add    r0, r4, #0
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr


        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.MoveRegister)
        self.assertEquals(g_classified[0].sources, [ReilRegisterOperand("r4", 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r0", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    # TODO: test_move_register_n: mul r0, r4, #1

    def test_load_constant_1(self):
        # testing : dst_reg <- constant
        binary  = "\x0a\x20\xa0\xe3"                     # 0x00 : (4)  mov    r2, #10
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEquals(g_classified[0].sources, [ReilImmediateOperand(10, 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r2", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)
        self.assertFalse(ReilRegisterOperand("r2", 32) in g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_2(self):
        # testing : dst_reg <- constant
        binary  = "\x02\x20\x42\xe0"                     # 0x00 : (4)  sub    r2, r2, r2
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEquals(g_classified[0].sources, [ReilImmediateOperand(0, 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r2", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)
        self.assertFalse(ReilRegisterOperand("r2", 32) in g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_3(self):
        # testing : dst_reg <- constant
        binary  = "\x02\x20\x22\xe0"                     # 0x00 : (4)  eor    r2, r2, r2
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEquals(g_classified[0].sources, [ReilImmediateOperand(0, 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r2", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)
        self.assertFalse(ReilRegisterOperand("r2", 32) in g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_4(self):
        # testing : dst_reg <- constant
        binary  = "\x00\x20\x02\xe2"                     # 0x00 : (4)  and    r2, r2, #0
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEquals(g_classified[0].sources, [ReilImmediateOperand(0, 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r2", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)
        self.assertFalse(ReilRegisterOperand("r2", 32) in g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_constant_5(self):
        # testing : dst_reg <- constant
        binary  = "\x00\x20\x02\xe2"                     # and    r2, r2, #0
        binary += "\x21\x20\x82\xe3"                     # orr    r2, r2, #33
        binary += "\x1e\xff\x2f\xe1"                     # bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadConstant)
        self.assertEquals(g_classified[0].sources, [ReilImmediateOperand(33, 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r2", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)
        self.assertFalse(ReilRegisterOperand("r2", 32) in g_classified[0].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_arithmetic_add_1(self):
        # testing : dst_reg <- src1_reg + src2_reg
        binary  = "\x08\x00\x84\xe0"                     # 0x00 : (4)  add    r0, r4, r8
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.Arithmetic)
        self.assertEquals(g_classified[0].sources, [ReilRegisterOperand("r4", 32), ReilRegisterOperand("r8", 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r0", 32)])
        self.assertEquals(g_classified[0].operation, "+")

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_arithmetic_sub_1(self):
        # testing : dst_reg <- src1_reg + src2_reg
        binary  = "\x08\x00\x44\xe0"                     # 0x00 : (4)  sub    r0, r4, r8
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.Arithmetic)
        self.assertEquals(g_classified[0].sources, [ReilRegisterOperand("r4", 32), ReilRegisterOperand("r8", 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r0", 32)])
        self.assertEquals(g_classified[0].operation, "-")

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_memory_1(self):
        # testing : dst_reg <- m[src_reg]
        binary  = "\x00\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4]
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadMemory)
        self.assertEquals(g_classified[0].sources, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r3", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_load_memory_2(self):
        # testing : dst_reg <- m[src_reg + offset]
        binary  = "\x33\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4 + 0x33]
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.LoadMemory)
        self.assertEquals(g_classified[0].sources, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x33, 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r3", 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    # TODO: ARM's ldr rd, [rn, r2] is not a valid classification right now

    def test_store_memory_1(self):
        # testing : dst_reg <- m[src_reg]
        binary  = "\x00\x30\x84\xe5"                     # 0x00 : (4)  str    r3, [r4]
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.StoreMemory)
        self.assertEquals(g_classified[0].sources, [ReilRegisterOperand("r3", 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_store_memory_2(self):
        # testing : dst_reg <- m[src_reg + offset]
        binary  = "\x33\x30\x84\xe5"                     # 0x00 : (4)  str    r3, [r4 + 0x33]
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 1)

        self.assertEquals(g_classified[0].type, GadgetType.StoreMemory)
        self.assertEquals(g_classified[0].sources, [ReilRegisterOperand("r3", 32)])
        self.assertEquals(g_classified[0].destination, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x33, 32)])

        self.assertEquals(len(g_classified[0].modified_registers), 0)

        self.assertTrue(self._g_verifier.verify(g_classified[0]))

    def test_arithmetic_load_add_1(self):
        # testing : dst_reg <- dst_reg + mem[src_reg]
        binary  = "\x00\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4]
        binary += "\x03\x00\x80\xe0"                     # 0x00 : (4)  add    r0, r0, r3
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 2)

        self.assertEquals(g_classified[1].type, GadgetType.ArithmeticLoad)
        self.assertEquals(g_classified[1].sources, [ReilRegisterOperand("r0", 32), ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32)])
        self.assertEquals(g_classified[1].destination, [ReilRegisterOperand("r0", 32)])
        self.assertEquals(g_classified[1].operation, "+")

        self.assertEquals(len(g_classified[1].modified_registers), 1)

        self.assertFalse(ReilRegisterOperand("r0", 32) in g_classified[1].modified_registers)
        self.assertTrue(ReilRegisterOperand("r3", 32) in g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def test_arithmetic_load_add_2(self):
        # testing : dst_reg <- dst_reg + mem[src_reg + offset]
        binary  = "\x22\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4, 0x22]
        binary += "\x03\x00\x80\xe0"                     # 0x00 : (4)  add    r0, r0, r3
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 2)

        self.assertEquals(g_classified[1].type, GadgetType.ArithmeticLoad)
        self.assertEquals(g_classified[1].sources, [ReilRegisterOperand("r0", 32), ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x22, 32)])
        self.assertEquals(g_classified[1].destination, [ReilRegisterOperand("r0", 32)])
        self.assertEquals(g_classified[1].operation, "+")

        self.assertEquals(len(g_classified[1].modified_registers), 1)

        self.assertFalse(ReilRegisterOperand("r0", 32) in g_classified[1].modified_registers)
        self.assertTrue(ReilRegisterOperand("r3", 32) in g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def test_arithmetic_store_add_1(self):
        # testing : m[dst_reg] <- m[dst_reg] + src_reg
        binary  = "\x00\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4]
        binary += "\x03\x30\x80\xe0"                     # 0x00 : (4)  add    r3, r0, r3
        binary += "\x00\x30\x84\xe5"                     # 0x00 : (4)  str    r3, [r4]
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 2)

        self.assertEquals(g_classified[1].type, GadgetType.ArithmeticStore)
        self.assertEquals(g_classified[1].sources, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32), ReilRegisterOperand("r0", 32)])
        self.assertEquals(g_classified[1].destination, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32)])
        self.assertEquals(g_classified[1].operation, "+")

        self.assertEquals(len(g_classified[1].modified_registers), 1)

        self.assertFalse(ReilRegisterOperand("r4", 32) in g_classified[1].modified_registers)
        self.assertTrue(ReilRegisterOperand("r3", 32) in g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def test_arithmetic_store_add_2(self):
        # testing : dst_reg <- dst_reg + mem[src_reg + offset]
        binary  = "\x22\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4, 0x22]
        binary += "\x03\x30\x80\xe0"                     # 0x00 : (4)  add    r3, r0, r3
        binary += "\x22\x30\x84\xe5"                     # 0x00 : (4)  str    r3, [r4, 0x22]
        binary += "\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEquals(len(g_candidates), 1)
        self.assertEquals(len(g_classified), 2)

        self.assertEquals(g_classified[1].type, GadgetType.ArithmeticStore)
        self.assertEquals(g_classified[1].sources, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x22, 32), ReilRegisterOperand("r0", 32)])
        self.assertEquals(g_classified[1].destination, [ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x22, 32)])
        self.assertEquals(g_classified[1].operation, "+")

        self.assertEquals(len(g_classified[1].modified_registers), 1)

        self.assertFalse(ReilRegisterOperand("r4", 32) in g_classified[1].modified_registers)
        self.assertTrue(ReilRegisterOperand("r3", 32) in g_classified[1].modified_registers)

        self.assertTrue(self._g_verifier.verify(g_classified[1]))

    def _print_candidates(self, candidates):
        print "Candidates :"

        for gadget in candidates:
            print gadget
            print "-" * 10

    def _print_classified(self, classified):
        print "Classified :"

        for gadget in classified:
            print gadget
            print gadget.type
            print "-" * 10
Exemple #29
0
class ReilEmulatorTests(unittest.TestCase):
    def setUp(self):
        self._arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        self._emulator = ReilEmulator(self._arch_info.address_size)

        self._emulator.set_arch_registers(self._arch_info.registers_gp_all)
        self._emulator.set_arch_registers_size(self._arch_info.registers_size)
        self._emulator.set_reg_access_mapper(self._arch_info.alias_mapper)

        self._asm_parser = X86Parser()
        self._translator = X86Translator()

    def test_add(self):
        asm_instrs = self._asm_parser.parse("add eax, ebx")

        self.__set_address(0xdeadbeef, [asm_instrs])

        reil_instrs = self._translator.translate(asm_instrs)

        regs_initial = {
            "eax": 0x1,
            "ebx": 0x2,
        }

        regs_final, _ = self._emulator.execute_lite(reil_instrs,
                                                    context=regs_initial)

        self.assertEqual(regs_final["eax"], 0x3)
        self.assertEqual(regs_final["ebx"], 0x2)

    def test_loop(self):
        # 0x08048060 : b8 00 00 00 00   mov eax,0x0
        # 0x08048065 : bb 0a 00 00 00   mov ebx,0xa
        # 0x0804806a : 83 c0 01         add eax,0x1
        # 0x0804806d : 83 eb 01         sub ebx,0x1
        # 0x08048070 : 83 fb 00         cmp ebx,0x0
        # 0x08048073 : 75 f5            jne 0x0804806a

        asm_instrs_str = [(0x08048060, "mov eax,0x0", 5)]
        asm_instrs_str += [(0x08048065, "mov ebx,0xa", 5)]
        asm_instrs_str += [(0x0804806a, "add eax,0x1", 3)]
        asm_instrs_str += [(0x0804806d, "sub ebx,0x1", 3)]
        asm_instrs_str += [(0x08048070, "cmp ebx,0x0", 3)]
        asm_instrs_str += [(0x08048073, "jne 0x0804806a", 2)]

        asm_instrs = []

        for addr, asm, size in asm_instrs_str:
            asm_instr = self._asm_parser.parse(asm)
            asm_instr.address = addr
            asm_instr.size = size

            asm_instrs.append(asm_instr)

        reil_instrs = [
            self._translator.translate(instr) for instr in asm_instrs
        ]

        regs_final, _ = self._emulator.execute(reil_instrs,
                                               0x08048060 << 8,
                                               context=[])

        self.assertEqual(regs_final["eax"], 0xa)
        self.assertEqual(regs_final["ebx"], 0x0)

    def test_mov(self):
        asm_instrs = [self._asm_parser.parse("mov eax, 0xdeadbeef")]
        asm_instrs += [self._asm_parser.parse("mov al, 0x12")]
        asm_instrs += [self._asm_parser.parse("mov ah, 0x34")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self._translator.translate(asm_instrs[0])
        reil_instrs += self._translator.translate(asm_instrs[1])
        reil_instrs += self._translator.translate(asm_instrs[2])

        regs_initial = {
            "eax": 0xffffffff,
        }

        regs_final, _ = self._emulator.execute_lite(reil_instrs,
                                                    context=regs_initial)

        self.assertEqual(regs_final["eax"], 0xdead3412)

    def __set_address(self, address, asm_instrs):
        addr = address

        for asm_instr in asm_instrs:
            asm_instr.address = addr
            addr += 1
class ArmTranslationTests(unittest.TestCase):
    def setUp(self):
        """Create the parser, translator and emulator shared by the tests."""
        mode = ARCH_ARM_MODE_THUMB
        arch = ArmArchitectureInformation(mode)

        self.trans_mode = FULL_TRANSLATION
        self.arch_mode = mode
        self.arch_info = arch
        self.arm_parser = ArmParser(mode)
        self.arm_translator = ArmTranslator(mode, FULL_TRANSLATION)
        self.reil_emulator = ReilEmulator(arch)

        # File used to persist the initial context of a failing run.
        self.context_filename = "failing_context.data"

    def __init_context(self):
        """Initialize register with random values.
        """
        context = self.__create_random_context()

        return context

    def __create_random_context(self):
        context = {}

        for reg in self.arch_info.registers_gp_base:
            if reg not in ['r13', 'r14', 'r15']:
                min_value, max_value = 0, 2**self.arch_info.operand_size - 1
                context[reg] = random.randint(min_value, max_value)

        # Only highest 4 bits (N, C, Z, V) are randomized, the rest are
        # left in the default (user mode) value.
        context['apsr'] = 0x00000010
        context['apsr'] |= random.randint(0x0, 0xF) << 28

        return context

    def __load_failing_context(self):
        f = open(self.context_filename, "rb")
        context = pickle.load(f)
        f.close()

        return context

    def __save_failing_context(self, context):
        f = open(self.context_filename, "wb")
        pickle.dump(context, f)
        f.close()

    def __compare_contexts(self, context_init, arm_context, reil_context):
        match = True
        mask = 2**32 - 1

        for reg in sorted(context_init.keys()):
            if (arm_context[reg] & mask) != (reil_context[reg] & mask):
                match = False
                break

        return match

    def __print_contexts(self, context_init, arm_context, reil_context):
        out = "Contexts don't match!\n\n"

        header_fmt = " {0:^8s} : {1:^16s} | {2:>16s} ?= {3:<16s}\n"
        header = header_fmt.format("Register", "Initial", "ARM", "REIL")
        ruler = "-" * len(header) + "\n"

        out += header
        out += ruler

        fmt = " {0:>8s} : {1:08x} | {2:08x} {eq} {3:08x} {marker}\n"

        mask = 2**64 - 1

        for reg in sorted(context_init.keys()):
            if (arm_context[reg] & mask) != (reil_context[reg] & mask):
                eq, marker = "!=", "<"
            else:
                eq, marker = "==", ""

            out += fmt.format(reg,
                              context_init[reg] & mask,
                              arm_context[reg] & mask,
                              reil_context[reg] & mask,
                              eq=eq,
                              marker=marker)

        # Pretty print flags.
        reg = "apsr"
        fmt = "{0:s} ({1:>4s}) : {2:08x} ({3:s})"

        arm_value = arm_context[reg] & mask
        reil_value = reil_context[reg] & mask

        if arm_value != reil_value:
            arm_flags_str = self.__print_flags(arm_context[reg])
            reil_flags_str = self.__print_flags(reil_context[reg])

            out += "\n"
            out += fmt.format(reg, "ARM", arm_value, arm_flags_str) + "\n"
            out += fmt.format(reg, "reil", reil_value, reil_flags_str)

        return out

    def __print_flags(self, flags_reg):
        # flags
        flags = {
            31: "nf",  # bit 31
            30: "zf",  # bit 30
            29: "cf",  # bit 29
            28: "vf",  # bit 28
        }

        out = ""

        for bit, flag in flags.items():
            flag_str = flag.upper() if flags_reg & 2**bit else flag.lower()
            out += flag_str + " "

        return out[:-1]

    def __set_address(self, address, arm_instrs):
        addr = address

        for arm_instr in arm_instrs:
            arm_instr.address = addr
            arm_instr.size = 4
            addr += 4

    def __translate(self, asm_instrs):
        """Translate ARM instructions into a ReilContainer.

        Each instruction becomes one ReilSequence holding its REIL
        translation. Consecutive sequences are chained: the
        ``next_sequence_address`` of a sequence is set to the address of
        the sequence that follows it.
        """
        instr_container = ReilContainer()

        instr_seq_prev = None

        for asm_instr in asm_instrs:
            instr_seq = ReilSequence()

            for reil_instr in self.arm_translator.translate(asm_instr):
                instr_seq.append(reil_instr)

            # Link the previous sequence to the one just built.
            if instr_seq_prev:
                instr_seq_prev.next_sequence_address = instr_seq.address

            instr_container.add(instr_seq)

            instr_seq_prev = instr_seq

        # NOTE: The last sequence's next_sequence_address is left as
        # ReilSequence() initialised it. An earlier version had a branch
        # meant to set it to (address + size) << 8 of the last
        # instruction, but it was guarded by a variable that was never
        # assigned, so it never ran; it was removed as dead code.

        return instr_container

    def __asm_to_reil_use_parser(self, asm_list, address):
        """Parse, place at *address* and translate a list of assembly lines.

        Unlike __asm_to_reil, the instructions are built with the ARM
        parser instead of being compiled and disassembled.
        """
        parsed = []

        for line in asm_list:
            parsed.append(self.arm_parser.parse(line))

        self.__set_address(address, parsed)

        return self.__translate(parsed)

    def __asm_to_reil(self, asm_list, address):
        """Compile, disassemble and translate a list of assembly lines.

        The lines are assembled with _arm_compile, decoded one
        instruction at a time with ArmDisassembler, re-addressed to
        start at *address* and translated to REIL.

        Raises Exception if the disassembler returns a false value.
        """
        # Using gcc:
        machine_code = self._arm_compile("\n".join(asm_list))
        total = len(machine_code)

        # Third argument handed to disassemble() for every instruction.
        USE_ARM = 0

        disassembler = ArmDisassembler()
        decoded = []
        offset = 0

        while offset < total:
            # Hand the disassembler a window of at most 16 bytes.
            window = machine_code[offset:min(offset + 16, total)]
            arm_instr = disassembler.disassemble(window, 0x8000, USE_ARM)

            if not arm_instr:
                raise Exception("Error in capstone disassembly")

            decoded.append(arm_instr)

            # Advance past the instruction just decoded.
            offset += arm_instr.size

        # TODO: Separate parser tests vs CS->BARF translator

        self.__set_address(address, decoded)

        return self.__translate(decoded)

    def _arm_compile(self, assembly):
        # TODO: This is a copy of the pyasmjit

        # Initialize return values
        rc = 0
        ctx = {}

        # Create temporary files for compilation.
        f_asm = tempfile.NamedTemporaryFile(delete=False)
        f_obj = tempfile.NamedTemporaryFile(delete=False)
        f_bin = tempfile.NamedTemporaryFile(delete=False)

        # Write assembly to a file.
        f_asm.write(
            assembly + '\n'
        )  # TODO: -mthumb is not working (so the "code" directive is added)
        f_asm.close()

        # Run nasm.
        cmd_fmt = "gcc -c -x assembler {asm} -o {obj} -mthumb -march=armv7-a; objcopy -O binary {obj} {bin};"
        cmd = cmd_fmt.format(asm=f_asm.name, obj=f_obj.name, bin=f_bin.name)
        return_code = subprocess.call(cmd, shell=True)

        # Check for assembler errors.
        if return_code == 0:
            # Read binary code.
            binary = ""
            byte = f_bin.read(1)
            while byte:
                binary += byte
                byte = f_bin.read(1)
            f_bin.close()

        else:
            raise Exception("gcc error")

        # Remove temporary files.
        os.remove(f_asm.name)
        os.remove(f_obj.name)
        os.remove(f_bin.name)

        return binary

    def __run_code(self, asm_list, address, ctx_init):
        """Run the code through pyasmjit and on the REIL emulator.

        Both runs start from *ctx_init*. Returns the pair
        ``(arm_ctx_out, reil_ctx_out)`` of resulting contexts.
        """
        reil_instrs = self.__asm_to_reil(asm_list, address)

        # Run through pyasmjit; only the output context is kept.
        source = "\n".join(asm_list)
        native_result = pyasmjit.arm_execute(source, ctx_init)
        _, arm_ctx_out, _ = native_result

        # Emulated execution; REIL addresses are shifted left by 8 bits.
        emulated_result = self.reil_emulator.execute(reil_instrs,
                                                     start=address << 8,
                                                     registers=ctx_init)
        reil_ctx_out, _ = emulated_result

        return arm_ctx_out, reil_ctx_out

    def __test_asm_instruction(self, asm_list):
        """Assert that pyasmjit and the REIL emulator agree on *asm_list*.

        The code runs from a random initial context. On a mismatch the
        initial context is saved to disk before the assertion fails.
        """
        initial = self.__init_context()

        native_out, emulated_out = self.__run_code(asm_list, 0xdeadbeef,
                                                   initial)

        contexts_match = self.__compare_contexts(initial, native_out,
                                                 emulated_out)

        if not contexts_match:
            self.__save_failing_context(initial)

        # The report is built unconditionally and used as failure message.
        report = self.__print_contexts(initial, native_out, emulated_out)

        self.assertTrue(contexts_match, report)

    def __test_asm_instruction_with_mem(self, asm_list, address_register):
        """Like __test_asm_instruction, for code that accesses memory.

        A 4096-byte buffer is allocated through pyasmjit and its address
        is placed in *address_register* before running the code. After
        both runs, every byte of the memory returned by pyasmjit is
        compared with the emulator's memory, and then the register
        contexts are compared. The pyasmjit buffer is released even if
        an assertion fails.
        """
        # TODO: Merge with previous test function.

        mem_addr = pyasmjit.arm_alloc(4096)

        try:
            self.reil_emulator.reset()

            reil_instrs = self.__asm_to_reil(asm_list, 0xdeadbeef)

            ctx_init = self.__init_context()
            ctx_init[address_register] = mem_addr

            _, arm_ctx_out, arm_mem_out = pyasmjit.arm_execute(
                "\n".join(asm_list), ctx_init)
            reil_ctx_out, reil_mem_out = self.reil_emulator.execute(
                reil_instrs, 0xdeadbeef << 8, registers=ctx_init)

            base_addr = mem_addr

            for idx, b in enumerate(
                    struct.unpack("B" * len(arm_mem_out), arm_mem_out)):
                addr = base_addr + idx

                # TODO: Don't access variable directly.
                if addr in reil_mem_out._memory:
                    self.assertEqual(b, reil_mem_out.read(addr, 1))
                else:
                    # Memory in pyasmjit is initialized to 0.
                    self.assertEqual(b, 0x0)

            cmp_result = self.__compare_contexts(ctx_init, arm_ctx_out,
                                                 reil_ctx_out)

            if not cmp_result:
                self.__save_failing_context(ctx_init)

            self.assertTrue(
                cmp_result,
                self.__print_contexts(ctx_init, arm_ctx_out, reil_ctx_out))
        finally:
            # NOTE: There is only one memory pool, so there is no need
            # (for now) to specify the address.
            pyasmjit.arm_free()

    def __execute_asm(self, asm_list, address=0x8000):
        """Emulate *asm_list* placed at *address* and return the registers.

        The code is built with the ARM parser and run on the REIL
        emulator only, starting from a random initial context.
        """
        container = self.__asm_to_reil_use_parser(asm_list, address)

        initial = self.__init_context()

        # REIL addresses are the native address shifted left by 8 bits.
        outcome = self.reil_emulator.execute(container,
                                             address << 8,
                                             registers=initial)
        final_regs, _ = outcome

        return final_regs

    def test_data_instructions(self):
        # Feeds data-processing samples, one list of instructions per
        # sample, to __test_asm_instruction.
        #
        # TODO: R13 (SP), R14 (LR), R15 (PC) are outside testing scope
        # for now.

        # Samples that do not update the flags.
        plain_samples = [
            ["mov r0, r1"],
            ["mov r3, r8"],
            ["mov r5, r8"],
            ["and r0, r1, r2"],
            ["and r0, r6, #0x33"],
            ["orr r3, r5, r8"],
            ["orr r3, r5, #0x79"],
            ["orr r3, r5, r8, lsl #0x19"],
            ["eor r3, r5, r8"],
            ["eor r8, r4, r5, lsl r6"],
            ["eor r8, r4, r5, lsl #0x11"],
            ["add r8, r9, r11"],
            ["sub r0, r3, r12"],
            ["rsb r0, r3, r12"],
            ["cmp r3, r12"],
            ["cmn r3, r12"],
            ["mov r8, r5, lsl r6"],
            ["eor r8, r4, r5, lsl r6"],
            ["mul r3, r4, r8"],
            ["mov r8, #0", "mul r3, r4, r8"],
            ["mul r3, r4, r4"],
            # ["movw r5, #0x1235"], # Not supported by the Raspberry Pi (ARMv6)
            ["mvn r3, r8"],
            # ["lsl r7, r2"], # TODO: Not implemented yet.
            ["lsl r2, r4, #0x0"],
            ["lsl r2, r4, #0x1"],
            ["lsl r2, r4, #10"],
            ["lsl r2, r4, #31"],
        ]

        # Samples whose last instruction updates the flags.
        flag_update_samples = [
            ["movs r0, #0"],
            ["movs r0, #-10"],
            ["mov r0, #0x7FFFFFFF", "mov r1, r0", "adds r3, r0, r1"],
            ["mov r0, #0x7FFFFFFF", "mov r1, r0", "subs r3, r0, r1"],
            ["mov r0, #0x00FFFFFF", "add r1, r0, #10", "subs r3, r0, r1"],
            ["mov r0, #0x00FFFFFF", "add r1, r0, #10", "rsbs r3, r0, r1"],
            ["mov r0, #0xFFFFFFFF", "adds r3, r0, #10"],
            ["mov r0, #0x7FFFFFFF", "mov r1,  #5", "adds r3, r0, r1"],
            ["mov r0, #0x80000000", "mov r1,  #5", "subs r3, r0, r1"],
            ["mov r0, #0x80000000", "mov r1,  #5", "rsbs r3, r0, r1"],
            ["lsls r2, r4, #0x0"],
            ["lsls r2, r4, #0x1"],
            ["lsls r2, r4, #10"],
            ["lsls r2, r4, #31"],
        ]

        # TODO: CHECK ReilCpuInvalidAddressError !!!!

        # Flags evaluation samples (currently disabled):
        #     ["moveq r0, r1"],
        #     ["movne r3, r8"],
        #     ["movcs r5, r8"],
        #     ["andcc r0, r1, r2"],
        #     ["andmi r0, r6, #0x33"],
        #     ["orrpl r3, r5, r8"],
        #     ["orrvs r3, r5, #0x79"],
        #     ["orrvc r3, r5, r8, lsl #0x19"],
        #     ["eorhi r3, r5, r8"],
        #     ["eorls r8, r4, r5, lsl r6"],
        #     ["eorge r8, r4, r5, lsl #0x11"],
        #     ["addlt r8, r9, r11"],
        #     ["subgt r0, r3, r12"],
        #     # ["rsbgt r0, r3, r12"], #TODO: Check this after the AddressError Fix
        #     ["cmple r3, r12"],
        #     ["cmnal r3, r12"],
        #     ["addhs r8, r9, r11"],
        #     ["sublo r0, r3, r12"],
        #     # ["rsblo r0, r3, r12"], #TODO: Check this after the AddressError Fix

        # Run the plain samples first, then the flag-updating ones.
        for sample in plain_samples + flag_update_samples:
            self.__test_asm_instruction(sample)

    def test_memory_instructions(self):
        # Feeds load/store samples to __test_asm_instruction_with_mem.
        # R12 is loaded with the memory address.
        address_register = 'r12'

        # Single and double register transfers.
        single_transfer_samples = [
            ["str r0, [r12]", "ldr  r1, [r12]"],
            ["strb r0, [r12]", "ldrb  r1, [r12]"],
            ["strh r0, [r12]", "ldrh  r1, [r12]"],
            ["strd r6, [r12]", "ldrd  r2, [r12]"],
            ["strd r6, r7, [r12]", "ldrd  r8, r9, [r12]"],
        ]

        # Load/store multiple in its various addressing modes.
        multiple_transfer_samples = [
            ["stm r12!, {r0 - r4}", "ldmdb r12, {r5 - r9}"],
            ["stmia r12, {r8 - r9}", "ldmia r12, {r1 - r2}"],
            ["stmib r12!, {r11}", "ldmda r12!, {r3}"],
            [
                "add r12, r12, #0x100",
                "stmda r12, {r9 - r11}",
                "ldmda r12, {r1 - r3}",
            ],
            [
                "add r12, r12, #0x100",
                "stmdb r12!, {r3}",
                "ldmia r12!, {r9}",
            ],
            ["add r12, r12, #0x100", "stmfd r12, {r2 - r4}"],
            ["stmfa r12!, {r1 - r4}"],
            ["add r12, r12, #0x100", "stmed r12, {r6 - r9}"],
            ["stmea r12!, {r5 - r7}"],
        ]

        stack_samples = [
            # NOTE: The last instr. is needed because the emulator has
            # no access to the real value of native r13 (SP) which is
            # not passed in the context.
            [
                "mov r0, r13",
                "mov r13, r12",
                "push {r1 - r10}",
                "pop {r2 - r11}",
                "mov r13, r0",
                "mov r0, #0",
            ],

            # TODO: Investigate sporadic seg fault in RPi
            #     ["mov r0, r13",
            #      "add r13, r12",
            #      "push {r2 - r11}",
            #      "pop {r1 - r10}",
            #      "mov r13, r0",
            #      "mov r0, #0"],
        ]

        all_samples = (
            single_transfer_samples + multiple_transfer_samples + stack_samples
        )

        for sample in all_samples:
            self.__test_asm_instruction_with_mem(sample, address_register)

    def test_branch_instructions(self):
        # Runs each sample on the REIL emulator through __execute_asm and
        # checks that the instruction following the branch did not run.
        untouched_value = 0x45454545
        touched_value = 0x31313131

        # R11 is used as a dirty register to check if the branch was
        # taken or not.
        set_untouched = "mov r11, #0x{:x}".format(untouched_value)
        set_touched = "mov r11, #0x{:x}".format(touched_value)

        # NOTE(review): the branch targets below appear to assume that
        # each instruction occupies 4 bytes starting at 0x8000, so that
        # they point at the trailing "mov r0, r0" -- confirm against the
        # translation helper.
        instr_samples = [
            [
                set_untouched,
                "b #0x800c",
                set_touched,
                "mov r0, r0",
            ],
            [
                set_untouched,
                "bx #0x800c",
                set_touched,
                "mov r0, r0",
            ],
            [
                set_untouched,
                "bl #0x800c",
                set_touched,
                "mov r0, r0",
            ],
            [
                set_untouched,
                "blx #0x800c",
                set_touched,
                "mov r0, r0",
            ],
            [
                # "movs" (not "mov") so the flags are updated before the
                # conditional branch.
                "movs r11, #0x{:x}".format(untouched_value),
                "bne #0x800c",
                set_touched,
                "mov r0, r0",
            ],
            [
                set_untouched,
                "mov r1, #0x8010",
                "bx r1",
                set_touched,
                "mov r0, r0",
            ],
            [
                set_untouched,
                "mov r1, #0x8010",
                "blx r1",
                set_touched,
                "mov r0, r0",
            ],
        ]

        for instr in instr_samples:
            reil_ctx_out = self.__execute_asm(instr, 0x8000)

            # assertEqual reports both values, and the message names the
            # failing sample, which assertTrue(a == b) could not do.
            self.assertEqual(
                reil_ctx_out['r11'], untouched_value,
                "r11 was overwritten in sample: {}".format("; ".join(instr)))