class CodeAnalyzerTests(unittest.TestCase):
    """Path-satisfiability test for CodeAnalyzer on a 32-bit x86 sample."""

    def setUp(self):
        """Create the SMT translator, disassembler, IR translator and CFG recoverer."""
        self._arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)
        self._operand_size = self._arch_info.operand_size
        self._memory = Memory()
        self._smt_solver = SmtSolver()
        self._smt_translator = SmtTranslator(self._smt_solver,
                                             self._operand_size)
        self._smt_translator.set_arch_alias_mapper(
            self._arch_info.alias_mapper)
        self._smt_translator.set_arch_registers_size(
            self._arch_info.registers_size)
        self._disasm = X86Disassembler()
        self._ir_translator = X86Translator()
        self._bb_builder = CFGRecoverer(
            RecursiveDescent(self._disasm, self._memory, self._ir_translator,
                             self._arch_info))

    def test_check_path_satisfiability(self):
        # Every simple basic-block path between `start` and `end` of the
        # sample below is expected to be satisfiable.
        #
        # print(...) is called with a single argument throughout so that the
        # output is identical on Python 2 and the module still parses on
        # Python 3.
        if VERBOSE:
            print("[+] Test: test_check_path_satisfiability")

        # binary : stack1
        bin_start_address, bin_end_address = 0x08048ec0, 0x8048f02

        # Bytes literals (identical to str on Python 2) so that bytearray()
        # below also accepts them on Python 3.
        binary = (
            b"\x55"                          # 0x08048ec0 : push   ebp
            b"\x89\xe5"                      # 0x08048ec1 : mov    ebp,esp
            b"\x83\xec\x60"                  # 0x08048ec3 : sub    esp,0x60
            b"\x8d\x45\xfc"                  # 0x08048ec6 : lea    eax,[ebp-0x4]
            b"\x89\x44\x24\x08"              # 0x08048ec9 : mov    DWORD PTR [esp+0x8],eax
            b"\x8d\x45\xac"                  # 0x08048ecd : lea    eax,[ebp-0x54]
            b"\x89\x44\x24\x04"              # 0x08048ed0 : mov    DWORD PTR [esp+0x4],eax
            b"\xc7\x04\x24\xa8\x5a\x0c\x08"  # 0x08048ed4 : mov    DWORD PTR [esp],0x80c5aa8
            b"\xe8\xa0\x0a\x00\x00"          # 0x08048edb : call   8049980 <_IO_printf>
            b"\x8d\x45\xac"                  # 0x08048ee0 : lea    eax,[ebp-0x54]
            b"\x89\x04\x24"                  # 0x08048ee3 : mov    DWORD PTR [esp],eax
            b"\xe8\xc5\x0a\x00\x00"          # 0x08048ee6 : call   80499b0 <_IO_gets>
            b"\x8b\x45\xfc"                  # 0x08048eeb : mov    eax,DWORD PTR [ebp-0x4]
            b"\x3d\x44\x43\x42\x41"          # 0x08048eee : cmp    eax,0x41424344
            b"\x75\x0c"                      # 0x08048ef3 : jne    8048f01 <main+0x41>
            b"\xc7\x04\x24\xc0\x5a\x0c\x08"  # 0x08048ef5 : mov    DWORD PTR [esp],0x80c5ac0
            b"\xe8\x4f\x0c\x00\x00"          # 0x08048efc : call   8049b50 <_IO_puts>
            b"\xc9"                          # 0x08048f01 : leave
            b"\xc3"                          # 0x08048f02 : ret
        )

        self._memory.add_vma(bin_start_address, bytearray(binary))

        # Path endpoints: function entry and the `leave` instruction.
        start = 0x08048ec0
        end = 0x08048f01

        # Initial register values handed to the analyzer as context.
        register_values = [
            ("eax", 0xffffd0ec),
            ("ecx", 0x00000001),
            ("edx", 0xffffd0e4),
            ("ebx", 0x00000000),
            ("esp", 0xffffd05c),
            ("ebp", 0x08049580),
            ("esi", 0x00000000),
            ("edi", 0x08049620),
            ("eip", 0x08048ec0),
        ]
        registers = {
            name: GenericRegister(name, 32, value)
            for name, value in register_values
        }

        # Initial flag values.
        flag_values = [
            ("af", 0x0),
            ("cf", 0x0),
            ("of", 0x0),
            ("pf", 0x1),
            ("sf", 0x0),
            ("zf", 0x1),
        ]
        flags = {name: GenericFlag(name, value) for name, value in flag_values}

        memory = {}

        # build() returns a pair; only the basic-block list is needed here.
        bb_list, _ = self._bb_builder.build(bin_start_address,
                                            bin_end_address)
        bb_graph = ControlFlowGraph(bb_list)

        code_analyzer = CodeAnalyzer(self._smt_solver, self._smt_translator,
                                     self._arch_info)
        code_analyzer.set_context(GenericContext(registers, flags, memory))

        # NOTE(review): if all_simple_bb_paths() yields nothing, the loop body
        # never runs and the test passes without asserting anything.
        for bb_path in bb_graph.all_simple_bb_paths(start, end):
            if VERBOSE:
                print("[+] Checking path satisfiability :")
                print("      From : %s" % hex(start))
                print("      To : %s" % hex(end))
                print("      Path : %s" % " -> ".join(
                    hex(bb.address) for bb in bb_path))

            is_sat = code_analyzer.check_path_satisfiability(bb_path,
                                                             start,
                                                             verbose=False)

            if VERBOSE:
                print("[+] Satisfiability : %s" % str(is_sat))

            self.assertTrue(is_sat)

            # is_sat is known to be truthy past the assertion above.
            if VERBOSE:
                print(code_analyzer.get_context())
                print(":" * 80)
                print("")
class CodeAnalyzerTests(unittest.TestCase):
    """Path-satisfiability test for CodeAnalyzer using BasicBlockBuilder and MemoryMock."""

    def setUp(self):
        """Create the SMT translator, disassembler, IR translator and basic-block builder."""
        self._arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)
        self._operand_size = self._arch_info.operand_size
        self._memory = MemoryMock()
        self._smt_solver = SmtSolver()
        self._smt_translator = SmtTranslator(self._smt_solver, self._operand_size)
        self._smt_translator.set_arch_alias_mapper(self._arch_info.alias_mapper)
        self._smt_translator.set_arch_registers_size(self._arch_info.registers_size)
        self._disasm = X86Disassembler()
        self._ir_translator = X86Translator()
        self._bb_builder = BasicBlockBuilder(self._disasm, self._memory, self._ir_translator, self._arch_info)

    def test_check_path_satisfiability(self):
        # Every simple basic-block path between `start` and `end` of the
        # sample below is expected to be satisfiable.
        #
        # print(...) is called with a single argument throughout so that the
        # output is identical on Python 2 and the module still parses on
        # Python 3.
        if VERBOSE:
            print("[+] Test: test_check_path_satisfiability")

        # binary : stack1
        bin_start_address, bin_end_address = 0x08048ec0, 0x8048f02

        binary  = "\x55"                          # 0x08048ec0 : push   ebp
        binary += "\x89\xe5"                      # 0x08048ec1 : mov    ebp,esp
        binary += "\x83\xec\x60"                  # 0x08048ec3 : sub    esp,0x60
        binary += "\x8d\x45\xfc"                  # 0x08048ec6 : lea    eax,[ebp-0x4]
        binary += "\x89\x44\x24\x08"              # 0x08048ec9 : mov    DWORD PTR [esp+0x8],eax
        binary += "\x8d\x45\xac"                  # 0x08048ecd : lea    eax,[ebp-0x54]
        binary += "\x89\x44\x24\x04"              # 0x08048ed0 : mov    DWORD PTR [esp+0x4],eax
        binary += "\xc7\x04\x24\xa8\x5a\x0c\x08"  # 0x08048ed4 : mov    DWORD PTR [esp],0x80c5aa8
        binary += "\xe8\xa0\x0a\x00\x00"          # 0x08048edb : call   8049980 <_IO_printf>
        binary += "\x8d\x45\xac"                  # 0x08048ee0 : lea    eax,[ebp-0x54]
        binary += "\x89\x04\x24"                  # 0x08048ee3 : mov    DWORD PTR [esp],eax
        binary += "\xe8\xc5\x0a\x00\x00"          # 0x08048ee6 : call   80499b0 <_IO_gets>
        binary += "\x8b\x45\xfc"                  # 0x08048eeb : mov    eax,DWORD PTR [ebp-0x4]
        binary += "\x3d\x44\x43\x42\x41"          # 0x08048eee : cmp    eax,0x41424344
        binary += "\x75\x0c"                      # 0x08048ef3 : jne    8048f01 <main+0x41>
        binary += "\xc7\x04\x24\xc0\x5a\x0c\x08"  # 0x08048ef5 : mov    DWORD PTR [esp],0x80c5ac0
        binary += "\xe8\x4f\x0c\x00\x00"          # 0x08048efc : call   8049b50 <_IO_puts>
        binary += "\xc9"                          # 0x08048f01 : leave
        binary += "\xc3"                          # 0x08048f02 : ret

        # Load the sample into the mocked memory at its original address.
        self._memory.set_base_address(bin_start_address)
        self._memory.set_content(binary)

        # Path endpoints: function entry and the `leave` instruction.
        start = 0x08048ec0
        end = 0x08048f01

        # Initial register values handed to the analyzer as context.
        registers = {
            "eax": GenericRegister("eax", 32, 0xffffd0ec),
            "ecx": GenericRegister("ecx", 32, 0x00000001),
            "edx": GenericRegister("edx", 32, 0xffffd0e4),
            "ebx": GenericRegister("ebx", 32, 0x00000000),
            "esp": GenericRegister("esp", 32, 0xffffd05c),
            "ebp": GenericRegister("ebp", 32, 0x08049580),
            "esi": GenericRegister("esi", 32, 0x00000000),
            "edi": GenericRegister("edi", 32, 0x08049620),
            "eip": GenericRegister("eip", 32, 0x08048ec0),
        }

        # Initial flag values.
        flags = {
            "af": GenericFlag("af", 0x0),
            "cf": GenericFlag("cf", 0x0),
            "of": GenericFlag("of", 0x0),
            "pf": GenericFlag("pf", 0x1),
            "sf": GenericFlag("sf", 0x0),
            "zf": GenericFlag("zf", 0x1),
        }

        memory = {}

        bb_list = self._bb_builder.build(bin_start_address, bin_end_address)

        bb_graph = BasicBlockGraph(bb_list)
        # Debugging aids:
        # bb_graph.save("bb_graph.png")
        # bb_graph.save("bb_graph_ir.png", print_ir=True)

        code_analyzer = CodeAnalyzer(self._smt_solver, self._smt_translator, self._arch_info)
        code_analyzer.set_context(GenericContext(registers, flags, memory))

        # NOTE(review): if all_simple_bb_paths() yields nothing, the loop body
        # never runs and the test passes without asserting anything.
        for bb_path in bb_graph.all_simple_bb_paths(start, end):
            if VERBOSE:
                print("[+] Checking path satisfiability :")
                print("      From : %s" % hex(start))
                print("      To : %s" % hex(end))
                print("      Path : %s" % " -> ".join(hex(bb.address) for bb in bb_path))

            is_sat = code_analyzer.check_path_satisfiability(bb_path, start, verbose=False)

            if VERBOSE:
                print("[+] Satisfiability : %s" % str(is_sat))

            self.assertTrue(is_sat)

            # is_sat is known to be truthy past the assertion above.
            if VERBOSE:
                print(code_analyzer.get_context())
                print(":" * 80)
                print("")
class SmtTranslatorTests(unittest.TestCase):
    """Checks the formula strings SmtTranslator produces for single REIL instructions."""

    def setUp(self):
        """Create a REIL parser and an SMT translator configured for 32-bit x86."""
        self._address_size = 32
        self._parser = ReilParser()
        self._solver = SmtSolver()
        self._translator = SmtTranslator(self._solver, self._address_size)

        self._arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        self._translator.set_arch_alias_mapper(self._arch_info.alias_mapper)
        self._translator.set_arch_registers_size(
            self._arch_info.registers_size)

    # Helpers
    def _translate(self, reil_text):
        """Parse the single REIL instruction *reil_text* and return its translation."""
        instr = self._parser.parse([reil_text])[0]
        return self._translator.translate(instr)

    def _assert_translation(self, reil_text, expected):
        """Assert that *reil_text* translates to exactly one formula whose value is *expected*."""
        form = self._translate(reil_text)

        self.assertEqual(len(form), 1)
        self.assertEqual(form[0].value, expected)

    def _assert_unsupported(self, reil_text, mnemonic):
        """Assert that translating *reil_text* raises an "Unsupported instruction" error.

        The message is compared against ``str(exception)``: testing
        membership directly on the exception object raises TypeError on
        Python 3, where exceptions are not iterable.
        """
        # Parsing happens outside the assertRaises block so that only a
        # translation failure can satisfy the assertion.
        instr = self._parser.parse([reil_text])[0]

        with self.assertRaises(Exception) as context:
            self._translator.translate(instr)

        self.assertIn("Unsupported instruction : " + mnemonic,
                      str(context.exception))

    # Arithmetic Instructions
    def test_translate_add_1(self):
        # Same size operands.
        self._assert_translation(
            "add [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvadd t0_0 t1_0))")

    def test_translate_add_2(self):
        # Destination operand larger than source operands.
        self._assert_translation(
            "add [BYTE t0, BYTE t1, WORD t2]",
            "(= t2_1 (bvadd ((_ zero_extend 8) t0_0) ((_ zero_extend 8) t1_0)))")

    def test_translate_add_3(self):
        # Destination operand smaller than source operands.
        self._assert_translation(
            "add [WORD t0, WORD t1, BYTE t2]",
            "(= t2_1 ((_ extract 7 0) (bvadd t0_0 t1_0)))")

    def test_translate_add_4(self):
        # Mixed source operands.
        self._assert_translation(
            "add [BYTE t0, BYTE 0x12, WORD t2]",
            "(= t2_1 (bvadd ((_ zero_extend 8) t0_0) ((_ zero_extend 8) #x12)))")

    def test_translate_sub(self):
        self._assert_translation(
            "sub [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvsub t0_0 t1_0))")

    def test_translate_mul(self):
        self._assert_translation(
            "mul [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvmul t0_0 t1_0))")

    def test_translate_div(self):
        self._assert_translation(
            "div [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvudiv t0_0 t1_0))")

    def test_translate_mod(self):
        self._assert_translation(
            "mod [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvurem t0_0 t1_0))")

    def test_translate_bsh(self):
        self._assert_translation(
            "bsh [DWORD t0, DWORD t1, DWORD t2]",
            "(= t2_1 (ite (bvsge t1_0 #x00000000) (bvshl t0_0 t1_0) (bvlshr t0_0 (bvneg t1_0))))")

    # Bitwise Instructions
    def test_translate_and(self):
        self._assert_translation(
            "and [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvand t0_0 t1_0))")

    def test_translate_or(self):
        self._assert_translation(
            "or [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvor t0_0 t1_0))")

    def test_translate_xor(self):
        self._assert_translation(
            "xor [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvxor t0_0 t1_0))")

    # Data Transfer Instructions
    def test_translate_ldm(self):
        self._assert_translation(
            "ldm [DWORD t0, empty, BYTE t2]",
            "(= (select MEM_0 (bvadd t0_0 #x00000000)) t2_1)")

    def test_translate_stm(self):
        self._assert_translation(
            "stm [BYTE t0, empty, DWORD t2]",
            "(= MEM_1 (store MEM_0 (bvadd t2_0 #x00000000) t0_0))")

    def test_translate_str(self):
        self._assert_translation(
            "str [BYTE t0, empty, BYTE t2]",
            "(= t2_1 t0_0)")

    # Conditional Instructions
    def test_translate_bisz(self):
        self._assert_translation(
            "bisz [DWORD t0, empty, DWORD t2]",
            "(= t2_1 (ite (= t0_0 #x00000000) #x00000001 #x00000000))")

    def test_translate_jcc(self):
        self._assert_translation(
            "jcc [BIT t0, empty, DWORD t2]",
            "(not (= t0_0 #b0))")

    # Other Instructions
    def test_translate_undef(self):
        self._assert_unsupported("undef [empty, empty, DWORD t2]", "UNDEF")

    def test_translate_unkn(self):
        self._assert_unsupported("unkn [empty, empty, empty]", "UNKN")

    def test_translate_nop(self):
        # A nop translates to no formula at all.
        form = self._translate("nop [empty, empty, empty]")

        self.assertEqual(len(form), 0)

    # Extensions
    def test_translate_sext(self):
        self._assert_translation(
            "sext [BYTE t0, empty, WORD t2]",
            "(= t2_1 ((_ sign_extend 8) t0_0))")

    def test_translate_sdiv(self):
        self._assert_translation(
            "sdiv [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvsdiv t0_0 t1_0))")

    def test_translate_smod(self):
        self._assert_translation(
            "smod [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvsmod t0_0 t1_0))")
# Example #4
# 0
class ArmGadgetClassifierTests(unittest.TestCase):
    """Classification and verification tests for short ARM-mode gadgets.

    Each test assembles a few instructions, runs them through GadgetFinder
    and GadgetClassifier, checks the first classified gadget and finally
    asks GadgetVerifier to verify it.
    """

    def setUp(self):
        """Create the emulator, SMT stack, code analyzer, classifier and verifier."""
        self._arch_info = ArmArchitectureInformation(ARCH_ARM_MODE_ARM)
        self._smt_solver = SmtSolver()
        self._smt_translator = SmtTranslator(self._smt_solver, self._arch_info.address_size)

        self._ir_emulator = ReilEmulator(self._arch_info)

        self._smt_translator.set_arch_alias_mapper(self._arch_info.alias_mapper)
        self._smt_translator.set_arch_registers_size(self._arch_info.registers_size)

        self._code_analyzer = CodeAnalyzer(self._smt_solver, self._smt_translator, self._arch_info)

        self._g_classifier = GadgetClassifier(self._ir_emulator, self._arch_info)
        self._g_verifier = GadgetVerifier(self._code_analyzer, self._arch_info)

    @staticmethod
    def _reg(name):
        """Return the 32-bit REIL register operand called *name*."""
        return ReilRegisterOperand(name, 32)

    @staticmethod
    def _imm(value):
        """Return the 32-bit REIL immediate operand holding *value*."""
        return ReilImmediateOperand(value, 32)

    def _find_and_classify_gadgets(self, binary):
        """Find gadget candidates in *binary* and classify the first one.

        Returns a ``(candidates, classified)`` pair.
        """
        g_finder = GadgetFinder(
            ArmDisassembler(architecture_mode=ARCH_ARM_MODE_ARM),
            bytearray(binary),
            ArmTranslator(),
            ARCH_ARM,
            ARCH_ARM_MODE_ARM)

        g_candidates = g_finder.find(0x00000000, len(binary), instrs_depth=4)
        g_classified = self._g_classifier.classify(g_candidates[0])

#         Debug:
#         self._print_candidates(g_candidates)
#         self._print_classified(g_classified)

        return g_candidates, g_classified

    def _check_first_gadget(self, binary, g_type, sources, destination,
                            operation=None, classified_count=1,
                            modified=(), unmodified=()):
        """Run the assertion sequence shared by every test in this class.

        binary           -- machine code to search for gadgets.
        g_type           -- expected GadgetType of the first classified gadget.
        sources          -- expected ``sources`` operand list.
        destination      -- expected ``destination`` operand list.
        operation        -- expected ``operation`` string; not checked when None.
        classified_count -- expected number of classifications.
        modified         -- operands that must be in ``modified_registers``;
                            its length is also the expected size of that
                            collection.
        unmodified       -- operands that must not be in ``modified_registers``.
        """
        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), classified_count)

        gadget = g_classified[0]

        self.assertEqual(gadget.type, g_type)
        self.assertEqual(gadget.sources, sources)
        self.assertEqual(gadget.destination, destination)

        if operation is not None:
            self.assertEqual(gadget.operation, operation)

        self.assertEqual(len(gadget.modified_registers), len(modified))

        for operand in unmodified:
            self.assertNotIn(operand, gadget.modified_registers)

        for operand in modified:
            self.assertIn(operand, gadget.modified_registers)

        self.assertTrue(self._g_verifier.verify(gadget))

    def test_move_register_1(self):
        # testing : dst_reg <- src_reg
        binary  = b"\x04\x00\xa0\xe1"                     # 0x00 : (4)  mov    r0, r4
        binary += b"\x31\xff\x2f\xe1"                     # 0x04 : (4)  blx    r1

        self._check_first_gadget(
            binary, GadgetType.MoveRegister,
            sources=[self._reg("r4")],
            destination=[self._reg("r0")],
            classified_count=2,
            modified=[self._reg("r14")])

    def test_move_register_2(self):
        # testing : dst_reg <- src_reg
        binary  = b"\x00\x00\x84\xe2"                     # 0x00 : (4)  add    r0, r4, #0
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.MoveRegister,
            sources=[self._reg("r4")],
            destination=[self._reg("r0")])

    # TODO: test_move_register_n: mul r0, r4, #1

    def test_load_constant_1(self):
        # testing : dst_reg <- constant
        binary  = b"\x0a\x20\xa0\xe3"                     # 0x00 : (4)  mov    r2, #10
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.LoadConstant,
            sources=[self._imm(10)],
            destination=[self._reg("r2")],
            unmodified=[self._reg("r2")])

    def test_load_constant_2(self):
        # testing : dst_reg <- constant
        binary  = b"\x02\x20\x42\xe0"                     # 0x00 : (4)  sub    r2, r2, r2
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.LoadConstant,
            sources=[self._imm(0)],
            destination=[self._reg("r2")],
            unmodified=[self._reg("r2")])

    def test_load_constant_3(self):
        # testing : dst_reg <- constant
        binary  = b"\x02\x20\x22\xe0"                     # 0x00 : (4)  eor    r2, r2, r2
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.LoadConstant,
            sources=[self._imm(0)],
            destination=[self._reg("r2")],
            unmodified=[self._reg("r2")])

    def test_load_constant_4(self):
        # testing : dst_reg <- constant
        binary  = b"\x00\x20\x02\xe2"                     # 0x00 : (4)  and    r2, r2, #0
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.LoadConstant,
            sources=[self._imm(0)],
            destination=[self._reg("r2")],
            unmodified=[self._reg("r2")])

    def test_load_constant_5(self):
        # testing : dst_reg <- constant
        binary  = b"\x00\x20\x02\xe2"                     # 0x00 : (4)  and    r2, r2, #0
        binary += b"\x21\x20\x82\xe3"                     # 0x04 : (4)  orr    r2, r2, #33
        binary += b"\x1e\xff\x2f\xe1"                     # 0x08 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.LoadConstant,
            sources=[self._imm(33)],
            destination=[self._reg("r2")],
            unmodified=[self._reg("r2")])

    def test_arithmetic_add_1(self):
        # testing : dst_reg <- src1_reg + src2_reg
        binary  = b"\x08\x00\x84\xe0"                     # 0x00 : (4)  add    r0, r4, r8
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.Arithmetic,
            sources=[self._reg("r4"), self._reg("r8")],
            destination=[self._reg("r0")],
            operation="+")

    def test_arithmetic_sub_1(self):
        # testing : dst_reg <- src1_reg - src2_reg
        binary  = b"\x08\x00\x44\xe0"                     # 0x00 : (4)  sub    r0, r4, r8
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.Arithmetic,
            sources=[self._reg("r4"), self._reg("r8")],
            destination=[self._reg("r0")],
            operation="-")

    def test_load_memory_1(self):
        # testing : dst_reg <- m[src_reg]
        binary  = b"\x00\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4]
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.LoadMemory,
            sources=[self._reg("r4"), self._imm(0x0)],
            destination=[self._reg("r3")])

    def test_load_memory_2(self):
        # testing : dst_reg <- m[src_reg + offset]
        binary  = b"\x33\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4, #0x33]
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.LoadMemory,
            sources=[self._reg("r4"), self._imm(0x33)],
            destination=[self._reg("r3")])

    # TODO: ARM's ldr rd, [rn, r2] is not a valid classification right now

    def test_store_memory_1(self):
        # testing : m[dst_reg] <- src_reg
        binary  = b"\x00\x30\x84\xe5"                     # 0x00 : (4)  str    r3, [r4]
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.StoreMemory,
            sources=[self._reg("r3")],
            destination=[self._reg("r4"), self._imm(0x0)])

    def test_store_memory_2(self):
        # testing : m[dst_reg + offset] <- src_reg
        binary  = b"\x33\x30\x84\xe5"                     # 0x00 : (4)  str    r3, [r4, #0x33]
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.StoreMemory,
            sources=[self._reg("r3")],
            destination=[self._reg("r4"), self._imm(0x33)])

    def test_arithmetic_load_add_1(self):
        # testing : dst_reg <- dst_reg + mem[src_reg]
        binary  = b"\x00\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4]
        binary += b"\x03\x00\x80\xe0"                     # 0x04 : (4)  add    r0, r0, r3
        binary += b"\x1e\xff\x2f\xe1"                     # 0x08 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.ArithmeticLoad,
            sources=[self._reg("r0"), self._reg("r4"), self._imm(0x0)],
            destination=[self._reg("r0")],
            operation="+",
            classified_count=2,
            modified=[self._reg("r3")],
            unmodified=[self._reg("r0")])

    def test_arithmetic_load_add_2(self):
        # testing : dst_reg <- dst_reg + mem[src_reg + offset]
        binary  = b"\x22\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4, #0x22]
        binary += b"\x03\x00\x80\xe0"                     # 0x04 : (4)  add    r0, r0, r3
        binary += b"\x1e\xff\x2f\xe1"                     # 0x08 : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.ArithmeticLoad,
            sources=[self._reg("r0"), self._reg("r4"), self._imm(0x22)],
            destination=[self._reg("r0")],
            operation="+",
            classified_count=2,
            modified=[self._reg("r3")],
            unmodified=[self._reg("r0")])

    def test_arithmetic_store_add_1(self):
        # testing : m[dst_reg] <- m[dst_reg] + src_reg
        binary  = b"\x00\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4]
        binary += b"\x03\x30\x80\xe0"                     # 0x04 : (4)  add    r3, r0, r3
        binary += b"\x00\x30\x84\xe5"                     # 0x08 : (4)  str    r3, [r4]
        binary += b"\x1e\xff\x2f\xe1"                     # 0x0c : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.ArithmeticStore,
            sources=[self._reg("r4"), self._imm(0x0), self._reg("r0")],
            destination=[self._reg("r4"), self._imm(0x0)],
            operation="+",
            classified_count=2,
            modified=[self._reg("r3")],
            unmodified=[self._reg("r4")])

    def test_arithmetic_store_add_2(self):
        # testing : m[dst_reg + offset] <- m[dst_reg + offset] + src_reg
        binary  = b"\x22\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4, #0x22]
        binary += b"\x03\x30\x80\xe0"                     # 0x04 : (4)  add    r3, r0, r3
        binary += b"\x22\x30\x84\xe5"                     # 0x08 : (4)  str    r3, [r4, #0x22]
        binary += b"\x1e\xff\x2f\xe1"                     # 0x0c : (4)  bx     lr

        self._check_first_gadget(
            binary, GadgetType.ArithmeticStore,
            sources=[self._reg("r4"), self._imm(0x22), self._reg("r0")],
            destination=[self._reg("r4"), self._imm(0x22)],
            operation="+",
            classified_count=2,
            modified=[self._reg("r3")],
            unmodified=[self._reg("r4")])

    def _print_candidates(self, candidates):
        """Debug helper: print every gadget candidate."""
        print("Candidates :")

        for gadget in candidates:
            print(gadget)
            print("-" * 10)

    def _print_classified(self, classified):
        """Debug helper: print every classified gadget followed by its type."""
        print("Classified :")

        for gadget in classified:
            print(gadget)
            print(gadget.type)
            print("-" * 10)
# Example #5
# 0
class CodeAnalyzerTests(unittest.TestCase):
    """Check CodeAnalyzer on single x86 (32-bit mode) ``add`` instructions.

    Every test loads the REIL translation of one instruction, constrains the
    written location to differ from 42 beforehand and to equal 42 afterwards,
    and expects the analyzer to report 'sat' with values honouring both
    constraints.
    """

    def setUp(self):
        """Create the architecture info, solver, translators, parser and analyzer."""
        arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)
        solver = SmtSolver()

        smt_translator = SmtTranslator(solver, arch_info.address_size)
        smt_translator.set_arch_alias_mapper(arch_info.alias_mapper)
        smt_translator.set_arch_registers_size(arch_info.registers_size)

        self._arch_info = arch_info
        self._smt_solver = solver
        self._smt_translator = smt_translator

        self._x86_parser = X86Parser(ARCH_X86_MODE_32)
        self._x86_translator = X86Translator(ARCH_X86_MODE_32)

        self._code_analyzer = CodeAnalyzer(solver, smt_translator, arch_info)

    def test_add_reg_reg(self):
        """``add eax, ebx``: eax may differ from 42 before and equal 42 after."""
        analyzer = self._code_analyzer

        # Parse the x86 instruction and feed its REIL translation to the analyzer.
        parsed = [self._x86_parser.parse("add eax, ebx")]

        for instr in self.__asm_to_reil(parsed):
            analyzer.add_instruction(instr)

        # Build the pre- and post-condition on eax.
        before = analyzer.get_register_expr("eax", mode="pre")
        after = analyzer.get_register_expr("eax", mode="post")

        pre_condition = before != 42
        post_condition = after == 42

        analyzer.add_constraint(pre_condition)
        analyzer.add_constraint(post_condition)

        # The system must be satisfiable and the values must match.
        self.assertEqual(analyzer.check(), 'sat')
        self.assertNotEqual(analyzer.get_expr_value(before), 42)
        self.assertEqual(analyzer.get_expr_value(after), 42)

    def test_add_reg_mem(self):
        """``add eax, [ebx + 0x1000]``: eax may differ from 42 before and equal 42 after."""
        analyzer = self._code_analyzer

        # Parse the x86 instruction and feed its REIL translation to the analyzer.
        parsed = [self._x86_parser.parse("add eax, [ebx + 0x1000]")]

        for instr in self.__asm_to_reil(parsed):
            analyzer.add_instruction(instr)

        # Build the pre- and post-condition on eax.
        before = analyzer.get_register_expr("eax", mode="pre")
        after = analyzer.get_register_expr("eax", mode="post")

        pre_condition = before != 42
        post_condition = after == 42

        analyzer.add_constraint(pre_condition)
        analyzer.add_constraint(post_condition)

        # The system must be satisfiable and the values must match.
        self.assertEqual(analyzer.check(), 'sat')
        self.assertNotEqual(analyzer.get_expr_value(before), 42)
        self.assertEqual(analyzer.get_expr_value(after), 42)

    def test_add_mem_reg(self):
        """``add [eax + 0x1000], ebx``: the memory cell may differ from 42 before and equal 42 after."""
        analyzer = self._code_analyzer

        # Parse the x86 instruction and feed its REIL translation to the analyzer.
        parsed = [self._x86_parser.parse("add [eax + 0x1000], ebx")]

        for instr in self.__asm_to_reil(parsed):
            analyzer.add_instruction(instr)

        # The cell is addressed through the *initial* value of eax in both states.
        base = analyzer.get_register_expr("eax", mode="pre")

        memory_before = analyzer.get_memory(mode="pre")
        memory_after = analyzer.get_memory(mode="post")

        pre_condition = memory_before[base + 0x1000] != 42
        post_condition = memory_after[base + 0x1000] == 42

        analyzer.add_constraint(pre_condition)
        analyzer.add_constraint(post_condition)

        # The system must be satisfiable and the values must match.
        self.assertEqual(analyzer.check(), 'sat')

        value_before = analyzer.get_expr_value(memory_before[base + 0x1000])
        self.assertNotEqual(value_before, 42)

        value_after = analyzer.get_expr_value(memory_after[base + 0x1000])
        self.assertEqual(value_after, 42)

    def __asm_to_reil(self, instructions):
        """Number *instructions* 0, 1, 2, ... and return their REIL translation as one flat list."""
        # Every instruction gets its list position as address before any
        # translation happens.
        for position, instruction in enumerate(instructions):
            instruction.address = position

        # Each translation yields a sequence of REIL instructions; concatenate them.
        flattened = []

        for instruction in instructions:
            flattened.extend(self._x86_translator.translate(instruction))

        return flattened
class ArmGadgetClassifierTests(unittest.TestCase):
    """Classification and verification tests for ARM (ARM mode) gadgets.

    Each test assembles a tiny gadget from raw bytes, runs it through
    GadgetFinder and GadgetClassifier, and checks the first classification
    (type, sources, destination, optional operation, modified registers)
    before asking GadgetVerifier to confirm it.
    """

    def setUp(self):
        """Build the emulator, SMT stack, classifier and verifier used by every test."""
        self._arch_info = ArmArchitectureInformation(ARCH_ARM_MODE_ARM)
        self._smt_solver = SmtSolver()
        self._smt_translator = SmtTranslator(self._smt_solver, self._arch_info.address_size)

        self._ir_emulator = ReilEmulator(self._arch_info)

        self._smt_translator.set_arch_alias_mapper(self._arch_info.alias_mapper)
        self._smt_translator.set_arch_registers_size(self._arch_info.registers_size)

        self._code_analyzer = CodeAnalyzer(self._smt_solver, self._smt_translator, self._arch_info)

        self._g_classifier = GadgetClassifier(self._ir_emulator, self._arch_info)
        self._g_verifier = GadgetVerifier(self._code_analyzer, self._arch_info)

    def _find_and_classify_gadgets(self, binary):
        """Return ``(candidates, classified)`` for the gadgets found in *binary*.

        The whole byte string is searched from address 0 with an instruction
        depth of 4; only the first candidate is classified.
        """
        g_finder = GadgetFinder(
            ArmDisassembler(architecture_mode=ARCH_ARM_MODE_ARM),
            bytearray(binary),
            ArmTranslator(),
            ARCH_ARM,
            ARCH_ARM_MODE_ARM
        )

        g_candidates = g_finder.find(0x00000000, len(binary), instrs_depth=4)
        g_classified = self._g_classifier.classify(g_candidates[0])

        # Debug aid: call self._print_candidates(g_candidates) and
        # self._print_classified(g_classified) here to dump the gadgets.

        return g_candidates, g_classified

    def _assert_first_gadget(self, binary, gadget_type, sources, destination,
                             operation=None, classified_count=1,
                             modified=(), not_modified=()):
        """Classify *binary* and check its first classification.

        Asserts, in this order: exactly one candidate is found;
        *classified_count* classifications are produced; the first one has
        the given type, sources and destination (and *operation*, when one
        is given); its modified-register list has exactly ``len(modified)``
        entries, contains none of *not_modified* and all of *modified*; and
        the verifier accepts it.
        """
        g_candidates, g_classified = self._find_and_classify_gadgets(binary)

        self.assertEqual(len(g_candidates), 1)
        self.assertEqual(len(g_classified), classified_count)

        gadget = g_classified[0]

        self.assertEqual(gadget.type, gadget_type)
        self.assertEqual(gadget.sources, sources)
        self.assertEqual(gadget.destination, destination)

        # Only arithmetic-style gadgets carry an operation worth checking.
        if operation is not None:
            self.assertEqual(gadget.operation, operation)

        self.assertEqual(len(gadget.modified_registers), len(modified))

        for register in not_modified:
            self.assertNotIn(register, gadget.modified_registers)

        for register in modified:
            self.assertIn(register, gadget.modified_registers)

        self.assertTrue(self._g_verifier.verify(gadget))

    def test_move_register_1(self):
        # testing : dst_reg <- src_reg
        binary  = b"\x04\x00\xa0\xe1"                     # 0x00 : (4)  mov    r0, r4
        binary += b"\x31\xff\x2f\xe1"                     # 0x04 : (4)  blx    r1

        self._assert_first_gadget(
            binary,
            GadgetType.MoveRegister,
            sources=[ReilRegisterOperand("r4", 32)],
            destination=[ReilRegisterOperand("r0", 32)],
        )

    def test_move_register_2(self):
        # testing : dst_reg <- src_reg
        binary  = b"\x00\x00\x84\xe2"                     # 0x00 : (4)  add    r0, r4, #0
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.MoveRegister,
            sources=[ReilRegisterOperand("r4", 32)],
            destination=[ReilRegisterOperand("r0", 32)],
        )

    # TODO: test_move_register_n: mul r0, r4, #1

    def test_load_constant_1(self):
        # testing : dst_reg <- constant
        binary  = b"\x0a\x20\xa0\xe3"                     # 0x00 : (4)  mov    r2, #10
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.LoadConstant,
            sources=[ReilImmediateOperand(10, 32)],
            destination=[ReilRegisterOperand("r2", 32)],
            not_modified=[ReilRegisterOperand("r2", 32)],
        )

    def test_load_constant_2(self):
        # testing : dst_reg <- constant
        binary  = b"\x02\x20\x42\xe0"                     # 0x00 : (4)  sub    r2, r2, r2
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.LoadConstant,
            sources=[ReilImmediateOperand(0, 32)],
            destination=[ReilRegisterOperand("r2", 32)],
            not_modified=[ReilRegisterOperand("r2", 32)],
        )

    def test_load_constant_3(self):
        # testing : dst_reg <- constant
        binary  = b"\x02\x20\x22\xe0"                     # 0x00 : (4)  eor    r2, r2, r2
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.LoadConstant,
            sources=[ReilImmediateOperand(0, 32)],
            destination=[ReilRegisterOperand("r2", 32)],
            not_modified=[ReilRegisterOperand("r2", 32)],
        )

    def test_load_constant_4(self):
        # testing : dst_reg <- constant
        binary  = b"\x00\x20\x02\xe2"                     # 0x00 : (4)  and    r2, r2, #0
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.LoadConstant,
            sources=[ReilImmediateOperand(0, 32)],
            destination=[ReilRegisterOperand("r2", 32)],
            not_modified=[ReilRegisterOperand("r2", 32)],
        )

    def test_load_constant_5(self):
        # testing : dst_reg <- constant
        binary  = b"\x00\x20\x02\xe2"                     # 0x00 : (4)  and    r2, r2, #0
        binary += b"\x21\x20\x82\xe3"                     # 0x04 : (4)  orr    r2, r2, #33
        binary += b"\x1e\xff\x2f\xe1"                     # 0x08 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.LoadConstant,
            sources=[ReilImmediateOperand(33, 32)],
            destination=[ReilRegisterOperand("r2", 32)],
            not_modified=[ReilRegisterOperand("r2", 32)],
        )

    def test_arithmetic_add_1(self):
        # testing : dst_reg <- src1_reg + src2_reg
        binary  = b"\x08\x00\x84\xe0"                     # 0x00 : (4)  add    r0, r4, r8
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.Arithmetic,
            sources=[ReilRegisterOperand("r4", 32), ReilRegisterOperand("r8", 32)],
            destination=[ReilRegisterOperand("r0", 32)],
            operation="+",
        )

    def test_arithmetic_sub_1(self):
        # testing : dst_reg <- src1_reg - src2_reg
        binary  = b"\x08\x00\x44\xe0"                     # 0x00 : (4)  sub    r0, r4, r8
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.Arithmetic,
            sources=[ReilRegisterOperand("r4", 32), ReilRegisterOperand("r8", 32)],
            destination=[ReilRegisterOperand("r0", 32)],
            operation="-",
        )

    def test_load_memory_1(self):
        # testing : dst_reg <- m[src_reg]
        binary  = b"\x00\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4]
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.LoadMemory,
            sources=[ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32)],
            destination=[ReilRegisterOperand("r3", 32)],
        )

    def test_load_memory_2(self):
        # testing : dst_reg <- m[src_reg + offset]
        binary  = b"\x33\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4, #0x33]
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.LoadMemory,
            sources=[ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x33, 32)],
            destination=[ReilRegisterOperand("r3", 32)],
        )

    # TODO: ARM's ldr rd, [rn, r2] is not a valid classification right now

    def test_store_memory_1(self):
        # testing : m[dst_reg] <- src_reg
        binary  = b"\x00\x30\x84\xe5"                     # 0x00 : (4)  str    r3, [r4]
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.StoreMemory,
            sources=[ReilRegisterOperand("r3", 32)],
            destination=[ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32)],
        )

    def test_store_memory_2(self):
        # testing : m[dst_reg + offset] <- src_reg
        binary  = b"\x33\x30\x84\xe5"                     # 0x00 : (4)  str    r3, [r4, #0x33]
        binary += b"\x1e\xff\x2f\xe1"                     # 0x04 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.StoreMemory,
            sources=[ReilRegisterOperand("r3", 32)],
            destination=[ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x33, 32)],
        )

    def test_arithmetic_load_add_1(self):
        # testing : dst_reg <- dst_reg + m[src_reg]
        binary  = b"\x00\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4]
        binary += b"\x03\x00\x80\xe0"                     # 0x04 : (4)  add    r0, r0, r3
        binary += b"\x1e\xff\x2f\xe1"                     # 0x08 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.ArithmeticLoad,
            sources=[ReilRegisterOperand("r0", 32), ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32)],
            destination=[ReilRegisterOperand("r0", 32)],
            operation="+",
            classified_count=2,
            modified=[ReilRegisterOperand("r3", 32)],
            not_modified=[ReilRegisterOperand("r0", 32)],
        )

    def test_arithmetic_load_add_2(self):
        # testing : dst_reg <- dst_reg + m[src_reg + offset]
        binary  = b"\x22\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4, #0x22]
        binary += b"\x03\x00\x80\xe0"                     # 0x04 : (4)  add    r0, r0, r3
        binary += b"\x1e\xff\x2f\xe1"                     # 0x08 : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.ArithmeticLoad,
            sources=[ReilRegisterOperand("r0", 32), ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x22, 32)],
            destination=[ReilRegisterOperand("r0", 32)],
            operation="+",
            classified_count=2,
            modified=[ReilRegisterOperand("r3", 32)],
            not_modified=[ReilRegisterOperand("r0", 32)],
        )

    def test_arithmetic_store_add_1(self):
        # testing : m[dst_reg] <- m[dst_reg] + src_reg
        binary  = b"\x00\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4]
        binary += b"\x03\x30\x80\xe0"                     # 0x04 : (4)  add    r3, r0, r3
        binary += b"\x00\x30\x84\xe5"                     # 0x08 : (4)  str    r3, [r4]
        binary += b"\x1e\xff\x2f\xe1"                     # 0x0c : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.ArithmeticStore,
            sources=[ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32), ReilRegisterOperand("r0", 32)],
            destination=[ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x0, 32)],
            operation="+",
            classified_count=2,
            modified=[ReilRegisterOperand("r3", 32)],
            not_modified=[ReilRegisterOperand("r4", 32)],
        )

    def test_arithmetic_store_add_2(self):
        # testing : m[dst_reg + offset] <- m[dst_reg + offset] + src_reg
        binary  = b"\x22\x30\x94\xe5"                     # 0x00 : (4)  ldr    r3, [r4, #0x22]
        binary += b"\x03\x30\x80\xe0"                     # 0x04 : (4)  add    r3, r0, r3
        binary += b"\x22\x30\x84\xe5"                     # 0x08 : (4)  str    r3, [r4, #0x22]
        binary += b"\x1e\xff\x2f\xe1"                     # 0x0c : (4)  bx     lr

        self._assert_first_gadget(
            binary,
            GadgetType.ArithmeticStore,
            sources=[ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x22, 32), ReilRegisterOperand("r0", 32)],
            destination=[ReilRegisterOperand("r4", 32), ReilImmediateOperand(0x22, 32)],
            operation="+",
            classified_count=2,
            modified=[ReilRegisterOperand("r3", 32)],
            not_modified=[ReilRegisterOperand("r4", 32)],
        )

    def _print_candidates(self, candidates):
        """Debug helper: print every candidate gadget followed by a dashed rule."""
        print("Candidates :")

        for gadget in candidates:
            print(gadget)
            print("-" * 10)

    def _print_classified(self, classified):
        """Debug helper: print every classified gadget, its type and a dashed rule."""
        print("Classified :")

        for gadget in classified:
            print(gadget)
            print(gadget.type)
            print("-" * 10)
Example #7
0
class SymExecResult(object):
    """Solver-backed view of one symbolically executed path.

    On construction the initial state, every element of the path and the
    final state are fed to a CodeAnalyzer, and the resulting constraint
    system is checked for satisfiability. Afterwards ``query_register`` and
    ``query_memory`` read values of the *initial* ("pre") state from the
    analyzer.
    """

    def __init__(self, arch, initial_state, path, final_state):
        """Build the analyzer for *arch* and load the states and the path into it.

        Args:
            arch: architecture information; ``address_size``, ``alias_mapper``
                and ``registers_size`` are read from it.
            initial_state: object providing ``get_registers()``,
                ``get_memory()`` and ``get_constraints()`` for the start of
                the path.
            path: iterable of ``(reil_instr, branch_taken)`` pairs.
            final_state: same protocol as *initial_state*, for the end of the
                path.

        Raises:
            AssertionError: if the analyzer does not report "sat".
        """
        self.__initial_state = initial_state
        self.__path = path
        self.__final_state = final_state

        self.__arch = arch

        self.__smt_solver = None
        self.__smt_translator = None

        self.__code_analyzer = None

        self.__initialize_analyzer()

        self.__setup_solver()

    def query_register(self, register):
        """Return the analyzer's value for *register* in the initial ("pre") state."""
        # TODO: This method should return an iterator.

        smt_expr = self.__code_analyzer.get_register_expr(register, mode="pre")
        value = self.__code_analyzer.get_expr_value(smt_expr)

        return value

    def query_memory(self, address, size):
        """Return the analyzer's value for *size* units at *address* in the initial ("pre") state."""
        # TODO: This method should return an iterator.

        smt_expr = self.__code_analyzer.get_memory_expr(address, size, mode="pre")
        value = self.__code_analyzer.get_expr_value(smt_expr)

        return value

    # Auxiliary methods
    # ======================================================================== #
    def __initialize_analyzer(self):
        """Create the Z3 solver, the SMT translator and the code analyzer."""
        self.__smt_solver = Z3Solver()

        self.__smt_translator = SmtTranslator(self.__smt_solver, self.__arch.address_size)
        self.__smt_translator.set_arch_alias_mapper(self.__arch.alias_mapper)
        self.__smt_translator.set_arch_registers_size(self.__arch.registers_size)

        self.__code_analyzer = CodeAnalyzer(self.__smt_solver, self.__smt_translator, self.__arch)

    def __setup_solver(self):
        """Load initial state, path and final state, then require satisfiability.

        Raises:
            AssertionError: if the analyzer does not report "sat".
        """
        self.__set_initial_state(self.__initial_state)
        self.__add_trace_to_solver(self.__path)
        self.__set_final_state(self.__final_state)

        # Run the check as a plain statement: wrapped in an ``assert`` it
        # would be stripped under ``python -O``, so the solver would never be
        # queried and an unsatisfiable path would go unnoticed.
        result = self.__code_analyzer.check()

        # AssertionError is kept so existing callers see the same exception type.
        if result != "sat":
            raise AssertionError(
                "path constraints are not satisfiable (solver returned {!r})".format(result))

    def __apply_state(self, state, mode):
        """Constrain registers and memory of *state* in *mode* and add its extra constraints."""
        # Set registers
        for reg, val in state.get_registers().items():
            smt_expr = self.__code_analyzer.get_register_expr(reg, mode=mode)
            self.__code_analyzer.add_constraint(smt_expr == val)

        # Set memory (one cell of size 1 per address)
        for addr, val in state.get_memory().items():
            smt_expr = self.__code_analyzer.get_memory_expr(addr, 1, mode=mode)
            self.__code_analyzer.add_constraint(smt_expr == val)

        # Set constraints
        for constr in state.get_constraints():
            self.__code_analyzer.add_constraint(constr)

    def __set_initial_state(self, initial_state):
        """Apply *initial_state* to the "pre" side of the analyzer."""
        self.__apply_state(initial_state, "pre")

    def __set_final_state(self, final_state):
        """Apply *final_state* to the "post" side of the analyzer."""
        self.__apply_state(final_state, "post")

    def __add_trace_to_solver(self, trace):
        """Feed every ``(reil_instr, branch_taken)`` pair of *trace* to the analyzer.

        A JCC whose first operand is a register becomes a constraint on that
        operand (non-zero when the branch was taken, zero otherwise). Every
        other instruction is added to the analyzer as is.
        """
        for reil_instr, branch_taken in trace:
            if reil_instr.mnemonic == ReilMnemonic.JCC and isinstance(reil_instr.operands[0], ReilRegisterOperand):
                oprnd = reil_instr.operands[0]
                oprnd_expr = self.__code_analyzer.get_operand_expr(oprnd)

                branch_expr = oprnd_expr != 0x0 if branch_taken else oprnd_expr == 0x0

                self.__code_analyzer.add_constraint(branch_expr)
            else:
                self.__code_analyzer.add_instruction(reil_instr)
Example #8
0
class ReilSymbolicEmulator(object):

    def __init__(self, arch):
        """Create the emulator components and the SMT analyzer for *arch*.

        Args:
            arch: Architecture information object; its address_size is used
                for the memory and it is passed on to the tainter, emulator,
                CPU and analyzer.
        """
        self.__arch = arch

        # Execution components.
        self.__memory = ReilMemoryEx(arch.address_size)
        self.__tainter = ReilEmulatorTainter(self, arch=arch)
        self.__emulator = ReilEmulator(arch)
        self.__cpu = ReilCpu(self.__memory, arch=arch)

        # SMT components; filled in by __initialize_analyzer() below.
        self.__smt_solver = None
        self.__smt_translator = None
        self.__code_analyzer = None

        self.__initialize_analyzer()

    def find_address(self, container, start=None, end=None, find=None, avoid=None, initial_state=None):
        """Explore *container* and collect the traces that reach *find*.

        Args:
            container: REIL container to execute.
            start: Native address to start from; when falsy, execution starts
                at the address of container[0].
            end: Native address at which a path stops being explored.
            find: Native address to look for.
            avoid: Native addresses whose branch targets are not followed.
            initial_state: State used to seed the CPU and the solver.

        Returns:
            list: One trace (a list of (instruction, branch_taken) pairs) per
            path on which the *find* address was reached.
        """
        def as_reil(native_addr):
            # Falsy addresses (None or 0) are treated as "not provided".
            return to_reil_address(native_addr) if native_addr else None

        # Seed the CPU with the concrete values of the initial state.
        self.__set_cpu_state(initial_state)

        # Translate every native address argument into a REIL address.
        start = as_reil(start)
        end = as_reil(end)
        find = as_reil(find)
        avoid = [to_reil_address(addr) for addr in (avoid or [])]

        # Initial instruction pointer.
        ip = start or container[0].address

        pending_states = Queue()
        found_traces = []

        self.__fa_process_container(container, find, ip, end, avoid, initial_state, pending_states, [],
                                    found_traces)

        # The exploration only returns once every queued path has been visited.
        assert pending_states.empty()

        return found_traces

    def find_state(self, container, start=None, end=None, avoid=None, initial_state=None, final_state=None):
        """Explore *container* and collect the traces that can end in *final_state*.

        Args:
            container: REIL container to execute.
            start: Native address to start from; when falsy, execution starts
                at the address of container[0].
            end: Native address at which a path stops being explored.
            avoid: Native addresses whose branch targets are not followed.
            initial_state: State used to seed the CPU and the solver.
            final_state: State whose values must hold after the trace.

        Returns:
            list: One trace (a list of (instruction, branch_taken) pairs) per
            path for which the solver reported the final state satisfiable.
        """
        def as_reil(native_addr):
            # Falsy addresses (None or 0) are treated as "not provided".
            return to_reil_address(native_addr) if native_addr else None

        # Seed the CPU with the concrete values of the initial state.
        self.__set_cpu_state(initial_state)

        # Translate every native address argument into a REIL address.
        start = as_reil(start)
        end = as_reil(end)
        avoid = [to_reil_address(addr) for addr in (avoid or [])]

        # Initial instruction pointer.
        ip = start or container[0].address

        pending_states = Queue()
        found_traces = []

        self.__fs_process_container(container, final_state, ip, end, avoid, initial_state, pending_states,
                                    [], found_traces)

        # The exploration only returns once every queued path has been visited.
        assert pending_states.empty()

        return found_traces

    # Read/Write methods
    # ======================================================================== #
    def read_operand(self, operand):
        return self.__cpu.read_operand(operand)

    def write_operand(self, operand, value):
        self.__cpu.write_operand(operand, value)

    def read_memory(self, address, size):
        return self.__memory.read(address, size)

    def write_memory(self, address, size, value):
        self.__memory.write(address, size, value)

    # Auxiliary methods.
    # ======================================================================== #
    def __process_branch_direct(self, trace_current, target_addr, avoid, instr, execution_state, initial_state, taken):
        """Queue one outcome of a direct conditional branch if it is feasible.

        Unless *target_addr* is in *avoid*, the (instr, taken) pair is appended
        to *trace_current* in place, the solver is rebuilt from the initial
        state plus that trace, and when the check returns 'sat' a new
        execution state (target address, trace, deep copies of the CPU
        registers and memory) is put on *execution_state*.

        Args:
            trace_current (list): Trace leading to the branch; mutated in place.
            target_addr (int): REIL address reached by this branch outcome.
            avoid (list): Addresses that must not be followed.
            instr: The branch instruction.
            execution_state (Queue): Queue of pending execution states.
            initial_state: Initial state used to constrain the solver.
            taken (bool): Whether this outcome is the taken one.
        """
        if taken:
            taken_str = "TAKEN"
        else:
            taken_str = "NOT TAKEN"

        target_hi = target_addr >> 8
        target_lo = target_addr & 0xff

        # Targets the caller asked to avoid are only logged.
        if target_addr in avoid:
            logger.debug("[+] Avoiding target address ({:s}) : {:#08x}:{:02x}".format(taken_str, target_hi, target_lo))
            return

        logger.debug("[+] Checking target satisfiability ({:s}) : {:#08x}:{:02x} -> {:#08x}:{:02x}".format(taken_str, instr.address >> 8, instr.address & 0xff, target_hi, target_lo))

        # Record the branch decision on the trace (in place, visible to the caller).
        trace_current.append((instr, taken))

        # Rebuild the solver from scratch for this path.
        self.__reset_solver()
        self.__set_initial_state(initial_state)
        self.__add_trace_to_solver(trace_current)

        satisfiable = self.__code_analyzer.check() == 'sat'

        logger.debug("[+] Target satisfiable? : {}".format(satisfiable))

        if not satisfiable:
            return

        logger.debug("[+] Enqueueing target address ({:s}) : {:#08x}:{:02x}".format(taken_str, target_hi, target_lo))

        # Snapshot registers and memory so the path can be resumed later.
        snapshot = (target_addr, trace_current, copy.deepcopy(self.__cpu.registers), copy.deepcopy(self.__cpu.memory))

        execution_state.put(snapshot)

    def __process_branch_cond(self, instr, avoid, initial_state, execution_state, trace_current, not_taken_addr):
        """Fork execution at a conditional branch with an immediate target.

        Both outcomes (taken, then not taken) are handed to
        __process_branch_direct, which queues the feasible ones.

        Args:
            instr: The JCC instruction (its third operand is the target).
            avoid (list): Addresses that must not be followed.
            initial_state: Initial state used to constrain the solver.
            execution_state (Queue): Queue of pending execution states.
            trace_current (list): Trace leading to the branch; the taken
                outcome is appended to it in place.
            not_taken_addr (int): Address executed when the branch falls through.

        Returns:
            None: Always; execution resumes from the queued states.

        Raises:
            Exception: If the target operand is not an immediate (indirect jump).
        """
        # Indirect branch (for example: JCC cond, empty, target)
        if not isinstance(instr.operands[2], ReilImmediateOperand):
            raise Exception("Indirect jump not supported yet.")

        # Direct branch (for example: JCC cond, empty, 0x12345678:00)
        # The not-taken outcome works on a copy made before the taken outcome
        # appends its own entry to the shared trace.
        not_taken_trace = list(trace_current)

        outcomes = (
            (trace_current, instr.operands[2].immediate, True),
            (not_taken_trace, not_taken_addr, False),
        )

        for trace, target_addr, taken in outcomes:
            self.__process_branch_direct(trace, target_addr, avoid, instr, execution_state, initial_state,
                                         taken)

        return None

    def __process_branch_uncond(self, instr, trace_current, not_taken_addr):
        # TAKEN branch (for example: JCC 0x1, empty, oprnd2).
        if instr.operands[0].immediate != 0x0:
            # Direct branch (for example: JCC 0x1, empty, INTEGER)
            if isinstance(instr.operands[2], ReilImmediateOperand):
                trace_current += [(instr, None)]

                next_ip = self.__cpu.execute(instr)

            # Indirect branch (for example: JCC 0x1, empty, REGISTER)
            else:
                raise Exception("Indirect jump not supported yet.")

        # NOT TAKEN branch (for example: JCC 0x0, empty, oprnd2).
        else:
            next_ip = not_taken_addr

        return next_ip

    def __process_instr(self, instr, avoid, next_addr, initial_state, execution_state, trace_current):
        """Process a single REIL instruction.

        Non-branch instructions are appended to the trace and executed on the
        CPU. JCC instructions are dispatched to the conditional handler (first
        operand is a register) or to the unconditional one (otherwise).

        Args:
            instr (ReilInstruction): Instruction to process.
            avoid (list): List of addresses to avoid while executing the code.
            next_addr (int): Address of the following instruction.
            initial_state (State): Initial execution state.
            execution_state (Queue): Queue of execution states.
            trace_current (list): Current trace.

        Returns:
            int: The next address to execute, or None when a conditional
            branch was forked into the execution state queue.
        """
        # Anything that is not a branch: record, execute, fall through.
        if instr.mnemonic != ReilMnemonic.JCC:
            trace_current.append((instr, None))

            self.__cpu.execute(instr)

            return next_addr

        # Branch (JCC oprnd0, empty, oprnd2).
        address, index = split_address(instr.address)

        logger.debug("[+] Processing branch: {:#08x}:{:02x} : {}".format(address, index, instr))

        # Conditional branch (oprnd0 is a REGISTER).
        if isinstance(instr.operands[0], ReilRegisterOperand):
            return self.__process_branch_cond(instr, avoid, initial_state, execution_state, trace_current, next_addr)

        # Unconditional branch (oprnd0 is an INTEGER).
        return self.__process_branch_uncond(instr, trace_current, next_addr)

    # find_state's auxiliary methods
    # ======================================================================== #
    def __fs_process_container(self, container, final_state, start, end, avoid, initial_state, execution_state,
                               trace_current, trace_final):
        """Explore a REIL container looking for paths that satisfy *final_state*.

        Instructions are executed one at a time starting at *start*. After
        every conditional branch (a JCC whose first operand is a register) the
        solver is rebuilt from the initial state, the current trace and the
        final state; when it reports "sat" a copy of the trace is appended to
        *trace_final*. Whenever the current path ends, the next pending state
        is taken from *execution_state* until the queue is empty.

        Args:
            container (ReilContainer): REIL container to execute.
            final_state (State): State that must hold at the end of a trace.
            start (int): Start address.
            end (int): End address; a path stops when it is reached.
            avoid (list): List of addresses to avoid while executing the code.
            initial_state (State): Initial state.
            execution_state (Queue): Queue of execution states.
            trace_current (list): Trace of the path being explored.
            trace_final (list): Output list receiving the satisfying traces.

        Raises:
            ReilContainerInvalidAddressError: If an instruction or its
                following address cannot be obtained from the container.
        """
        ip = start

        while ip:
            # Fetch next instruction.
            try:
                instr = container.fetch(ip)
            except ReilContainerInvalidAddressError:
                logger.debug("Exception @ {:#08x}".format(ip))

                raise ReilContainerInvalidAddressError

            # Compute the address of the following instruction to the fetched one.
            try:
                next_addr = container.get_next_address(ip)
            except Exception:
                logger.debug("Exception @ {:#08x}".format(ip))

                # TODO Should this be considered an error?

                # NOTE Any exception raised here is reported as an invalid
                # address error.
                raise ReilContainerInvalidAddressError

            # Process instruction
            next_ip = self.__process_instr(instr, avoid, next_addr, initial_state, execution_state, trace_current)

            # Check if the final state holds. This is only done right after a
            # conditional branch.
            if instr.mnemonic == ReilMnemonic.JCC and isinstance(instr.operands[0], ReilRegisterOperand):
                # TODO Check only when it is necessary.
                self.__reset_solver()
                self.__set_initial_state(initial_state)
                self.__add_trace_to_solver(trace_current)
                self.__set_final_state(final_state)

                is_sat = self.__code_analyzer.check()

                if is_sat == "sat":
                    logger.debug("[+] Final state found!")

                    # Store a copy, the current trace list may keep growing.
                    trace_final.append(list(trace_current))

                    next_ip = None

            # Check termination conditions.
            if end and next_ip and next_ip == end:
                logger.debug("[+] End address found!")

                next_ip = None

            # Update instruction pointer.
            ip = next_ip if next_ip else None

            # The current path is over: resume from the pending execution
            # states until one provides an address to continue from.
            while not ip:
                if not execution_state.empty():
                    # Pop next execution state.
                    ip, trace_current, registers, memory = execution_state.get()

                    if split_address(ip)[1] == 0x0:
                        logger.debug("[+] Popping execution state @ {:#x} (INTER)".format(ip))
                    else:
                        logger.debug("[+] Popping execution state @ {:#x} (INTRA)".format(ip))

                    # Setup cpu and memory.
                    self.__cpu.registers = registers
                    self.__cpu.memory = memory

                    logger.debug("[+] Next address: {:#08x}:{:02x}".format(ip >> 8, ip & 0xff))
                else:
                    logger.debug("[+] No more paths to explore! Exiting...")
                    # ip is still falsy here, so the outer loop ends as well.
                    break

                # Check termination conditions (AGAIN...).
                if end and ip == end:
                    logger.debug("[+] End address found!")

                    # Discard this state and try the next pending one.
                    ip = None

    # find_address's auxiliary methods
    # ======================================================================== #
    def __fa_process_sequence(self, sequence, avoid, initial_state, execution_state, trace_current, next_addr):
        """Process the instructions of a single REIL sequence.

        Args:
            sequence (ReilSequence): A REIL sequence to process.
            avoid (list): List of addresses to avoid.
            initial_state: Initial state.
            execution_state: Execution state queue.
            trace_current (list): Current trace.
            next_addr: Address of the instruction that follows the sequence.

        Returns:
            The address at which execution leaves the sequence, or the last
            (possibly None) result of processing an instruction when there is
            no such address.
        """
        # TODO: Process execution intra states.

        current = sequence.address
        outcome = None

        while current:
            # Fetch the instruction at the current address.
            try:
                instr = sequence.fetch(current)
            except ReilSequenceInvalidAddressError:
                # The address lies outside the sequence, so it should be a
                # native instruction address (index zero).
                assert split_address(current)[1] == 0x0

                return current

            # Address used when the instruction falls through.
            try:
                fallthrough = sequence.get_next_address(current)
            except ReilSequenceInvalidAddressError:
                # End of the sequence: execution continues on the next native
                # instruction (it's a REIL address).
                fallthrough = next_addr

            outcome = self.__process_instr(instr, avoid, fallthrough, initial_state, execution_state, trace_current)

            # Follow the address returned by the instruction, if any.
            if outcome:
                current = outcome
                continue

            # Otherwise move on to the next instruction of the sequence.
            try:
                current = sequence.get_next_address(current)
            except ReilSequenceInvalidAddressError:
                break

        return outcome

    def __fa_process_container(self, container, find, start, end, avoid, initial_state, execution_state, trace_current,
                               trace_final):
        """Process a REIL container looking for paths that reach *find*.

        Instructions are executed one at a time starting at *start*. When the
        next address equals *find*, a copy of the current trace is appended to
        *trace_final*. Whenever the current path ends, the next pending state
        is taken from *execution_state* until the queue is empty.

        Args:
            avoid (list): List of addresses to avoid while executing the code.
            container (ReilContainer): REIL container to execute.
            end (int): End address.
            execution_state (Queue): Queue of execution states.
            find (int): Address to find.
            initial_state (State): Initial state.
            start (int): Start address.
            trace_current (list): Trace of the path being explored.
            trace_final (list): Output list receiving the traces that reach
                the find address.

        Raises:
            ReilContainerInvalidAddressError: If an instruction or its
                following address cannot be obtained from the container.
        """
        ip = start

        while ip:
            # NOTE *ip* and *next_addr* variables can be, independently, either intra
            # or inter addresses.

            # Fetch next instruction.
            try:
                instr = container.fetch(ip)
            except ReilContainerInvalidAddressError:
                logger.debug("Exception @ {:#08x}".format(ip))

                raise ReilContainerInvalidAddressError

            # Compute the address of the following instruction to the fetched one.
            try:
                next_addr = container.get_next_address(ip)
            except Exception:
                logger.debug("Exception @ {:#08x}".format(ip))

                # TODO Should this be considered an error?

                # NOTE Any exception raised here is reported as an invalid
                # address error.
                raise ReilContainerInvalidAddressError

            # Process the instruction.
            next_ip = self.__process_instr(instr, avoid, next_addr, initial_state, execution_state, trace_current)

            # # ====================================================================================================== #
            # # NOTE This is an attempt to separate intra and inter instruction
            # # addresses processing. Here, *ip* and *next_addr* are always inter
            # # instruction addresses.
            # # NOTE(review): the disabled code below calls __process_sequence,
            # # which is not defined in this part of the file (the sequence
            # # helper here is named __fa_process_sequence) - confirm before
            # # re-enabling.
            #
            # assert split_address(ip)[1] == 0x0
            #
            # # Compute the address of the following instruction to the fetched one.
            # try:
            #     seq = container.fetch_sequence(ip)
            # except ReilContainerInvalidAddressError:
            #     logger.debug("Exception @ {:#08x}".format(ip))
            #
            #     raise ReilContainerInvalidAddressError
            #
            # # Fetch next instruction address.
            # try:
            #     next_addr = container.get_next_address(ip + len(seq))
            # except Exception:
            #     logger.debug("Exception @ {:#08x}".format(ip))
            #
            #     # TODO Should this be considered an error?
            #
            #     raise ReilContainerInvalidAddressError
            #
            # next_ip = self.__process_sequence(seq, avoid, initial_state, execution_state, trace_current, next_addr)
            #
            # if next_ip:
            #     assert split_address(next_ip)[1] == 0x0
            #
            # # ====================================================================================================== #

            # Check termination conditions.
            if find and next_ip and next_ip == find:
                logger.debug("[+] Find address found!")

                # Store a copy, the current trace list may keep growing.
                trace_final.append(list(trace_current))

                next_ip = None

            if end and next_ip and next_ip == end:
                logger.debug("[+] End address found!")

                next_ip = None

            # Update instruction pointer.
            ip = next_ip if next_ip else None

            # The current path is over: resume from the pending execution
            # states until one provides an address to continue from.
            while not ip:
                if not execution_state.empty():
                    # Pop next execution state.
                    ip, trace_current, registers, memory = execution_state.get()

                    if split_address(ip)[1] == 0x0:
                        logger.debug("[+] Popping execution state @ {:#x} (INTER)".format(ip))
                    else:
                        logger.debug("[+] Popping execution state @ {:#x} (INTRA)".format(ip))

                    # Setup cpu and memory.
                    self.__cpu.registers = registers
                    self.__cpu.memory = memory

                    logger.debug("[+] Next address: {:#08x}:{:02x}".format(ip >> 8, ip & 0xff))
                else:
                    logger.debug("[+] No more paths to explore! Exiting...")

                    # ip is still falsy here, so the outer loop ends as well.
                    break

                # Check termination conditions (AGAIN...).
                if find and ip == find:
                    logger.debug("[+] Find address found!")

                    trace_final.append(list(trace_current))

                    # Discard this state and try the next pending one.
                    ip = None

                if end and ip == end:
                    logger.debug("[+] End address found!")

                    ip = None

    # Auxiliary methods
    # ======================================================================== #
    def __initialize_analyzer(self):
        """Create the Z3 solver, the SMT translator and the code analyzer."""
        arch = self.__arch

        self.__smt_solver = solver = Z3Solver()

        # The translator is configured with the architecture's alias mapper
        # and register sizes before the analyzer is built on top of it.
        self.__smt_translator = translator = SmtTranslator(solver, arch.address_size)
        translator.set_arch_alias_mapper(arch.alias_mapper)
        translator.set_arch_registers_size(arch.registers_size)

        self.__code_analyzer = CodeAnalyzer(solver, translator, arch)

    def __reset_solver(self):
        self.__code_analyzer.reset()

    def __set_cpu_state(self, state):
        # Set registers
        for reg, val in state.get_registers().items():
            self.__cpu.registers[reg] = val

        # Set memory
        for addr, val in state.get_memory().items():
            self.__cpu.write_memory(addr, 1, val)

    def __set_initial_state(self, initial_state):
        # Set registers
        for reg, val in initial_state.get_registers().items():
            smt_expr = self.__code_analyzer.get_register_expr(reg, mode="pre")
            self.__code_analyzer.add_constraint(smt_expr == val)

        # Set memory
        for addr, val in initial_state.get_memory().items():
            smt_expr = self.__code_analyzer.get_memory_expr(addr, 1, mode="pre")
            self.__code_analyzer.add_constraint(smt_expr == val)

        # Set constraints
        for constr in initial_state.get_constraints():
            self.__code_analyzer.add_constraint(constr)

    def __set_final_state(self, final_state):
        # Set registers
        for reg, val in final_state.get_registers().items():
            smt_expr = self.__code_analyzer.get_register_expr(reg, mode="post")
            self.__code_analyzer.add_constraint(smt_expr == val)

        # Set memory
        for addr, val in final_state.get_memory().items():
            smt_expr = self.__code_analyzer.get_memory_expr(addr, 1, mode="post")
            self.__code_analyzer.add_constraint(smt_expr == val)

        # Set constraints
        for constr in final_state.get_constraints():
            self.__code_analyzer.add_constraint(constr)

    def __add_trace_to_solver(self, trace):
        for reil_instr, branch_taken in trace:
            if reil_instr.mnemonic == ReilMnemonic.JCC and isinstance(reil_instr.operands[0], ReilRegisterOperand):
                oprnd = reil_instr.operands[0]
                oprnd_expr = self.__code_analyzer.get_operand_expr(oprnd)

                branch_expr = oprnd_expr != 0x0 if branch_taken else oprnd_expr == 0x0

                # logger.debug("    Branch: {:#010x}:{:02x}  {:s} ({}) - {:s}".format(reil_instr.address >> 8, reil_instr.address & 0xff, reil_instr, branch_taken, branch_expr))

                self.__code_analyzer.add_constraint(branch_expr)
            else:
                self.__code_analyzer.add_instruction(reil_instr)
Example #9
0
class State(object):
    """Set of register values, memory bytes and extra constraints.

    Registers are kept in a dict keyed by register name, memory in a dict
    keyed by byte address (one value per byte), and constraints in a list.
    The class also owns a code analyzer that is used to build the
    expressions returned by query_register() and query_memory().
    """

    def __init__(self, arch, mode=None):
        """Create an empty state for *arch*.

        Args:
            arch: Architecture information used to set up the analyzer.
            mode: Stored as given; nothing in this class reads it.
        """
        self._registers = {}
        self._memory = {}
        self._constraints = []
        self._mode = mode   # {"initial", "final"}

        self.__arch = arch

        self.__smt_solver = None
        self.__smt_translator = None

        self.__code_analyzer = None

        self.__initialize_analyzer()

    def read_register(self, register):
        """Return the value stored for *register*, or None if it is not set."""
        return self._registers.get(register, None)

    def write_register(self, register, value):
        """Store *value* for *register*."""
        self._registers[register] = value

    def query_register(self, register):
        """Return the analyzer's "pre"-mode expression for *register*."""
        smt_expr = self.__code_analyzer.get_register_expr(register, mode="pre")

        return smt_expr

    def get_registers(self):
        """Return a copy of the register dict."""
        return dict(self._registers)

    def read_memory(self, address, size):
        """Return the value stored at *address*, or None if it is not set.

        Multi-byte reads mirror write_memory(): the byte at *address* is the
        least significant one.

        Args:
            address (int): Address of the first byte.
            size (int): Number of bytes to read (at least 1).

        Returns:
            The stored value for a single byte, the combined integer for
            several bytes, or None when any of the requested bytes is unset.

        Raises:
            ValueError: If *size* is smaller than 1.
        """
        # Explicit check instead of an assert, which is stripped under -O.
        if size < 1:
            raise ValueError("size must be at least 1, got {}".format(size))

        # Single byte: return the stored entry untouched.
        if size == 1:
            return self._memory.get(address, None)

        value = 0

        for i in range(size):
            byte = self._memory.get(address + i, None)

            # A partially initialized range has no well-defined value.
            if byte is None:
                return None

            value |= byte << (i * 8)

        return value

    def write_memory(self, address, size, value):
        """Store the *size* low-order bytes of *value* starting at *address*.

        The least significant byte goes to the lowest address.
        """
        for i in range(0, size):
            self._memory[address + i] = (value >> (i * 8)) & 0xff

    def query_memory(self, address, size):
        """Return the analyzer's "pre"-mode expression for a memory range."""
        smt_expr = self.__code_analyzer.get_memory_expr(address, size, mode="pre")

        return smt_expr

    def get_memory(self):
        """Return a copy of the memory dict (byte address -> value)."""
        return dict(self._memory)

    def add_constraint(self, constraint):
        """Append *constraint* to the state's constraint list."""
        self._constraints.append(constraint)

    def get_constraints(self):
        """Return a copy of the constraint list."""
        return list(self._constraints)

    # Auxiliary methods
    # ======================================================================== #
    def __initialize_analyzer(self):
        """Create the Z3 solver, the SMT translator and the code analyzer."""
        self.__smt_solver = Z3Solver()

        self.__smt_translator = SmtTranslator(self.__smt_solver, self.__arch.address_size)
        self.__smt_translator.set_arch_alias_mapper(self.__arch.alias_mapper)
        self.__smt_translator.set_arch_registers_size(self.__arch.registers_size)

        self.__code_analyzer = CodeAnalyzer(self.__smt_solver, self.__smt_translator, self.__arch)