class ReilParserTests(unittest.TestCase):
    """Checks that ReilParser turns textual REIL into instruction objects."""

    def setUp(self):
        # A fresh parser per test keeps the cases independent.
        self._parser = ReilParser()

    def test_add(self):
        source = [
            "str [eax, EMPTY, t0]",
            "str [ebx, EMPTY, t1]",
            "add [t0, t1, t2]",
            "str [t2, EMPTY, eax]",
        ]
        # Operands written without a size prefix are rendered as "UNK".
        expected = [
            "str   [UNK eax, EMPTY, UNK t0]",
            "str   [UNK ebx, EMPTY, UNK t1]",
            "add   [UNK t0, UNK t1, UNK t2]",
            "str   [UNK t2, EMPTY, UNK eax]",
        ]

        parsed = self._parser.parse(source)

        for position, text in enumerate(expected):
            self.assertEqual(str(parsed[position]), text)

    def test_parse_operand_size(self):
        source = [
            "str [DWORD eax, EMPTY, DWORD t0]",
            "str [eax, EMPTY, DWORD t0]",
            "str [eax, EMPTY, t0]",
        ]
        # Expected (op0, op1, op2) sizes per instruction: a DWORD prefix gives
        # 32, EMPTY gives 0, and an operand without a prefix gives None.
        expected_sizes = [
            (32, 0, 32),
            (None, 0, 32),
            (None, 0, None),
        ]

        parsed = self._parser.parse(source)

        for instr_idx, sizes in enumerate(expected_sizes):
            operands = parsed[instr_idx].operands
            for oprnd_idx, size in enumerate(sizes):
                self.assertEqual(operands[oprnd_idx].size, size)
Exemplo n.º 2
0
class ReilParserTests(unittest.TestCase):
    """Unit tests for ReilParser: textual REIL in, instruction objects out."""

    def setUp(self):
        self._parser = ReilParser()

    def test_add(self):
        lines = ["str [eax, EMPTY, t0]",
                 "str [ebx, EMPTY, t1]",
                 "add [t0, t1, t2]",
                 "str [t2, EMPTY, eax]"]

        parsed = self._parser.parse(lines)

        # (index into the parse result, expected string rendering)
        checks = (
            (0, "str   [UNK eax, EMPTY, UNK t0]"),
            (1, "str   [UNK ebx, EMPTY, UNK t1]"),
            (2, "add   [UNK t0, UNK t1, UNK t2]"),
            (3, "str   [UNK t2, EMPTY, UNK eax]"),
        )
        for index, rendered in checks:
            self.assertEqual(str(parsed[index]), rendered)

    def test_parse_operand_size(self):
        lines = ["str [DWORD eax, EMPTY, DWORD t0]",
                 "str [eax, EMPTY, DWORD t0]",
                 "str [eax, EMPTY, t0]"]

        parsed = self._parser.parse(lines)

        # DWORD -> 32 bits, EMPTY -> 0, no size prefix -> None.
        self._assert_operand_sizes(parsed[0], (32, 0, 32))
        self._assert_operand_sizes(parsed[1], (None, 0, 32))
        self._assert_operand_sizes(parsed[2], (None, 0, None))

    def _assert_operand_sizes(self, instr, sizes):
        """Assert that operand ``i`` of *instr* has size ``sizes[i]``."""
        for index, size in enumerate(sizes):
            self.assertEqual(instr.operands[index].size, size)
Exemplo n.º 3
0
class ReilEmulatorTests(unittest.TestCase):
    """Tests for ReilEmulator running x86 code translated to REIL."""

    def setUp(self):
        self._arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        self._emulator = ReilEmulator(self._arch_info)

        self._asm_parser = X86Parser()
        self._reil_parser = ReilParser()

        self._translator = X86Translator()

    def test_add(self):
        asm_instr = self._asm_parser.parse("add eax, ebx")

        self.__set_address(0xdeadbeef, [asm_instr])

        reil_instrs = self._translator.translate(asm_instr)

        regs_initial = {
            "eax": 0x1,
            "ebx": 0x2,
        }

        regs_final, _ = self._emulator.execute_lite(reil_instrs,
                                                    context=regs_initial)

        self.assertEqual(regs_final["eax"], 0x3)
        self.assertEqual(regs_final["ebx"], 0x2)

    def test_loop(self):
        """Run a small counting loop through the container-based executor."""
        # 0x08048060 : b8 00 00 00 00   mov eax,0x0
        # 0x08048065 : bb 0a 00 00 00   mov ebx,0xa
        # 0x0804806a : 83 c0 01         add eax,0x1
        # 0x0804806d : 83 eb 01         sub ebx,0x1
        # 0x08048070 : 83 fb 00         cmp ebx,0x0
        # 0x08048073 : 75 f5            jne 0x0804806a

        asm_instrs_str = [(0x08048060, "mov eax,0x0", 5)]
        asm_instrs_str += [(0x08048065, "mov ebx,0xa", 5)]
        asm_instrs_str += [(0x0804806a, "add eax,0x1", 3)]
        asm_instrs_str += [(0x0804806d, "sub ebx,0x1", 3)]
        asm_instrs_str += [(0x08048070, "cmp ebx,0x0", 3)]
        asm_instrs_str += [(0x08048073, "jne 0x0804806a", 2)]

        asm_instrs = []

        for addr, asm, size in asm_instrs_str:
            asm_instr = self._asm_parser.parse(asm)
            asm_instr.address = addr
            asm_instr.size = size

            asm_instrs.append(asm_instr)

        reil_instrs = self.__translate(asm_instrs)

        # REIL addresses are the x86 address shifted left by 8 bits.
        regs_final, _ = self._emulator.execute(reil_instrs,
                                               start=0x08048060 << 8)

        self.assertEqual(regs_final["eax"], 0xa)
        self.assertEqual(regs_final["ebx"], 0x0)

    def test_mov(self):
        asm_instrs = [self._asm_parser.parse("mov eax, 0xdeadbeef")]
        asm_instrs += [self._asm_parser.parse("mov al, 0x12")]
        asm_instrs += [self._asm_parser.parse("mov ah, 0x34")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self._translator.translate(asm_instrs[0])
        reil_instrs += self._translator.translate(asm_instrs[1])
        reil_instrs += self._translator.translate(asm_instrs[2])

        regs_initial = {
            "eax": 0xffffffff,
        }

        regs_final, _ = self._emulator.execute_lite(reil_instrs,
                                                    context=regs_initial)

        self.assertEqual(regs_final["eax"], 0xdead3412)

    def test_pre_hanlder(self):
        """The pre-handler runs and receives the parameter registered with it."""
        def pre_hanlder(emulator, instruction, parameter):
            # Record the call through the argument the emulator passes in
            # (not a closure), so this also checks the parameter is forwarded.
            parameter.append(True)

        asm = ["mov eax, ebx"]

        # Real lists (not lazy ``map`` objects): they are iterated by
        # __set_address and then indexed below.
        x86_instrs = [self._asm_parser.parse(line) for line in asm]
        self.__set_address(0xdeadbeef, x86_instrs)
        reil_instrs = [self._translator.translate(i) for i in x86_instrs]

        calls = []

        self._emulator.set_instruction_pre_handler(pre_hanlder, calls)

        self._emulator.execute_lite(reil_instrs[0])

        self.assertTrue(len(calls) > 0)

    def test_post_hanlder(self):
        """The post-handler runs and receives the parameter registered with it."""
        def post_hanlder(emulator, instruction, parameter):
            # Record the call through the argument the emulator passes in
            # (not a closure), so this also checks the parameter is forwarded.
            parameter.append(True)

        asm = ["mov eax, ebx"]

        # Real lists (not lazy ``map`` objects): they are iterated by
        # __set_address and then indexed below.
        x86_instrs = [self._asm_parser.parse(line) for line in asm]
        self.__set_address(0xdeadbeef, x86_instrs)
        reil_instrs = [self._translator.translate(i) for i in x86_instrs]

        calls = []

        self._emulator.set_instruction_post_handler(post_hanlder, calls)

        self._emulator.execute_lite(reil_instrs[0])

        self.assertTrue(len(calls) > 0)

    def test_zero_division_error_1(self):
        asm_instrs = [self._asm_parser.parse("div ebx")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self._translator.translate(asm_instrs[0])

        regs_initial = {
            "eax": 0x2,
            "edx": 0x2,
            "ebx": 0x0,
        }

        self.assertRaises(ReilCpuZeroDivisionError,
                          self._emulator.execute_lite,
                          reil_instrs,
                          context=regs_initial)

    def test_zero_division_error_2(self):
        instrs = ["mod [DWORD eax, DWORD ebx, DWORD t0]"]

        reil_instrs = self._reil_parser.parse(instrs)

        reil_instrs[0].address = 0xdeadbeef00

        regs_initial = {
            "eax": 0x2,
            "ebx": 0x0,
        }

        self.assertRaises(ReilCpuZeroDivisionError,
                          self._emulator.execute_lite,
                          reil_instrs,
                          context=regs_initial)

    def test_invalid_address_error_1(self):
        asm_instrs = [self._asm_parser.parse("jmp eax")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self.__translate(asm_instrs)

        regs_initial = {
            "eax": 0xffffffff,
        }

        self.assertRaises(ReilCpuInvalidAddressError,
                          self._emulator.execute,
                          reil_instrs,
                          start=0xdeadbeef << 8,
                          registers=regs_initial)

    def test_invalid_address_error_2(self):
        asm_instrs = [self._asm_parser.parse("mov eax, 0xdeadbeef")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self.__translate(asm_instrs)

        regs_initial = {
            "eax": 0xffffffff,
        }

        # Start one byte past the only instruction.
        self.assertRaises(ReilCpuInvalidAddressError,
                          self._emulator.execute,
                          reil_instrs,
                          start=0xdeadbef0 << 8,
                          registers=regs_initial)

    # Auxiliary methods
    # ======================================================================== #
    def __set_address(self, address, asm_instrs):
        """Give *asm_instrs* consecutive addresses starting at *address*."""
        addr = address

        for asm_instr in asm_instrs:
            asm_instr.address = addr
            addr += 1

    def __translate(self, asm_instrs):
        """Translate *asm_instrs* into a ReilContainer of linked sequences.

        Each x86 instruction becomes one ReilSequence, and each sequence's
        ``next_sequence_address`` is set to the address of the sequence after
        it. The last sequence's ``next_sequence_address`` is not assigned
        here.
        """
        instr_container = ReilContainer()

        instr_seq_prev = None

        for asm_instr in asm_instrs:
            instr_seq = ReilSequence()

            for reil_instr in self._translator.translate(asm_instr):
                instr_seq.append(reil_instr)

            if instr_seq_prev:
                instr_seq_prev.next_sequence_address = instr_seq.address

            instr_container.add(instr_seq)

            instr_seq_prev = instr_seq

        # The last sequence is deliberately not linked to a fall-through
        # address (address + size): the invalid-address tests set only
        # ``address`` on their instructions, not ``size``.
        return instr_container
Exemplo n.º 4
0
class SmtTranslatorTests(unittest.TestCase):
    """Satisfiability checks for the REIL-to-SMT translator.

    Each test parses a short REIL sequence, adds the translated formulae and
    extra constraints to the solver, and asserts the expected sat/unsat result.
    """

    def setUp(self):
        self._address_size = 32
        self._parser = ReilParser()
        self._solver = SmtSolver()
        self._translator = SmtTranslator(self._solver, self._address_size)

    def test_add_reg_reg(self):
        if VERBOSE:
            print("\n[+] Test: test_add_reg_reg")

        # add eax, ebx
        instrs = self._parser.parse([
            "add [eax, ebx, t0]",
            "str [t0, e, eax]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[1].size = 32
        instrs[0].operands[2].size = 32

        instrs[1].operands[0].size = 32
        instrs[1].operands[2].size = 32

        self._solver.reset()

        # translate instructions to formulae
        self._translate_instrs(instrs)

        # current eax == 42 while initial eax != 42
        self._add_constraints([
            self._solver.mkBitVec(32, self._translator.get_curr_name("eax")) == 42,
            self._solver.mkBitVec(32, self._translator.get_init_name("eax")) != 42,
        ])

        is_sat = self._check_sat(["eax", "ebx"])

        self.assertEqual(is_sat, True)

    def test_add_reg_mem(self):
        if VERBOSE:
            print("\n[+] Test: test_add_reg_mem")

        # add eax, [ebx]
        instrs = self._parser.parse([
            "ldm [ebx, EMPTY, t0]",
            "add [eax, t0, t1]",
            "str [t1, EMPTY, eax]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[2].size = 32

        instrs[1].operands[0].size = 32
        instrs[1].operands[1].size = 32
        instrs[1].operands[2].size = 32

        instrs[2].operands[0].size = 32
        instrs[2].operands[2].size = 32

        self._solver.reset()

        # translate instructions to formulae
        self._translate_instrs(instrs)

        # current eax == 42 while initial eax != 42
        self._add_constraints([
            self._solver.mkBitVec(32, self._translator.get_curr_name("eax")) == 42,
            self._solver.mkBitVec(32, self._translator.get_init_name("eax")) != 42,
        ])

        is_sat = self._check_sat(["eax", "ebx"])

        self.assertEqual(is_sat, True)

    def test_add_mem_reg(self):
        if VERBOSE:
            print("\n[+] Test: test_add_mem_reg")

        # add [eax], ebx
        instrs = self._parser.parse([
            "ldm [eax, EMPTY, t0]",
            "add [t0, ebx, t1]",
            "stm [t1, EMPTY, eax]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[2].size = 32

        instrs[1].operands[0].size = 32
        instrs[1].operands[1].size = 32
        instrs[1].operands[2].size = 32

        instrs[2].operands[0].size = 32
        instrs[2].operands[2].size = 32

        self._solver.reset()

        # constrain memory before the instructions are translated
        mem = self._translator.get_memory()
        eax = self._solver.mkBitVec(32, "eax_0")

        constraint = (mem[eax] != 42)

        if VERBOSE:
            print("constraint : %s" % constraint)

        self._solver.add(constraint)

        # translate instructions to formulae
        self._translate_instrs(instrs)

        # NOTE(review): ``mem`` was fetched before translation; requiring
        # mem[eax] == 42 here only makes sense if the translator updates that
        # memory object in place on ``stm`` -- confirm against get_memory().
        self._add_constraints([
            mem[eax] == 42,
            self._solver.mkBitVec(32, self._translator.get_init_name("t0")) != 42,
        ])

        is_sat = self._check_sat(["eax", "ebx", "t0"], [("[eax]", mem[eax])])

        self.assertEqual(is_sat, True)

    def test_add_mem_reg_2(self):
        if VERBOSE:
            print("\n[+] Test: test_add_mem_reg_2")

        # add [eax + 0x1000], ebx
        instrs = self._parser.parse([
            "add [eax, 0x1000, t0]",
            "ldm [t0, e, t1]",
            "add [t1, ebx, t2]",
            "stm [t2, e, t0]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[1].size = 32
        instrs[0].operands[2].size = 32

        instrs[1].operands[0].size = 32
        instrs[1].operands[2].size = 32

        instrs[2].operands[0].size = 32
        instrs[2].operands[1].size = 32
        instrs[2].operands[2].size = 32

        instrs[3].operands[0].size = 32
        instrs[3].operands[2].size = 32

        self._solver.reset()

        # constrain memory before the instructions are translated
        mem = self._translator.get_memory()
        eax = self._solver.mkBitVec(32, "eax_0")
        off = BitVec(32, "#x%08x" % 0x1000)

        constraint = (mem[eax + off] != 42)

        if VERBOSE:
            print("constraint : %s" % constraint)

        self._solver.add(constraint)

        # translate instructions to formulae
        self._translate_instrs(instrs)

        # NOTE(review): same in-place-memory assumption as test_add_mem_reg.
        self._add_constraints([
            mem[eax + off] == 42,
            self._solver.mkBitVec(32, self._translator.get_init_name("t1")) != 42,
        ])

        is_sat = self._check_sat(["eax", "ebx", "t0", "t1"],
                                 [("[eax + off]", mem[eax + off])])

        self.assertEqual(is_sat, True)

    def test_mul(self):
        if VERBOSE:
            print("\n[+] Test: test_mul")

        instrs = self._parser.parse([
            "mul [0x0, 0x1, t0]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[1].size = 32

        # TODO: Check this; the output size should be 64.
        instrs[0].operands[2].size = 32

        self._solver.reset()

        # translate instructions to formulae
        self._translate_instrs(instrs)

        # current t0 == 0 while initial t0 != 0
        self._add_constraints([
            self._solver.mkBitVec(32, self._translator.get_curr_name("t0")) == 0,
            self._solver.mkBitVec(32, self._translator.get_init_name("t0")) != 0,
        ])

        is_sat = self._check_sat(["t0"])

        self.assertEqual(is_sat, True)

    def test_sext_1(self):
        instr = self._parser.parse(["sext [WORD 0xffff, EMPTY, DWORD t1]"])

        self._solver.reset()

        smt_expr = self._translator.translate(instr[0])

        self._solver.add(smt_expr[0])

        # add constraints
        constraints = [
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) != 0xffffffff,
        ]

        self._solver.add(constraints[0])

        self.assertEqual(self._solver.check(), 'unsat')

    def test_sext_2(self):
        instr = self._parser.parse(["sext [WORD 0x7fff, EMPTY, DWORD t1]"])

        self._solver.reset()

        smt_expr = self._translator.translate(instr[0])

        self._solver.add(smt_expr[0])

        # add constraints
        constraints = [
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) != 0x00007fff,
        ]

        self._solver.add(constraints[0])

        self.assertEqual(self._solver.check(), 'unsat')

    def test_bsh_left_1(self):
        instr = self._parser.parse(["bsh [DWORD t1, DWORD 16, QWORD t2]"])

        self._solver.reset()

        smt_expr = self._translator.translate(instr[0])

        self._solver.add(smt_expr[0])

        # add constraints
        constraints = [
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) == 0xffffffff,
            self._solver.mkBitVec(64, self._translator.get_curr_name("t2")) != 0x0000ffffffff0000,
        ]

        self._solver.add(constraints[0])
        self._solver.add(constraints[1])

        self.assertEqual(self._solver.check(), 'unsat')

    def test_bsh_left_2(self):
        instr = self._parser.parse(["bsh [DWORD t1, DWORD 16, DWORD t2]"])

        self._solver.reset()

        smt_expr = self._translator.translate(instr[0])

        self._solver.add(smt_expr[0])

        # add constraints
        constraints = [
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) == 0xffffffff,
            self._solver.mkBitVec(32, self._translator.get_curr_name("t2")) != 0xffff0000,
        ]

        self._solver.add(constraints[0])
        self._solver.add(constraints[1])

        self.assertEqual(self._solver.check(), 'unsat')

    def test_bsh_left_3(self):
        instr = self._parser.parse(["bsh [DWORD t1, DWORD 16, WORD t2]"])

        self._solver.reset()

        smt_expr = self._translator.translate(instr[0])

        self._solver.add(smt_expr[0])

        # add constraints
        constraints = [
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) == 0xffffffff,
            self._solver.mkBitVec(16, self._translator.get_curr_name("t2")) != 0x0000,
        ]

        self._solver.add(constraints[0])
        self._solver.add(constraints[1])

        self.assertEqual(self._solver.check(), 'unsat')

    def test_bsh_right_1(self):
        instr = self._parser.parse(["bsh [DWORD t1, DWORD -16, QWORD t2]"])

        self._solver.reset()

        smt_expr = self._translator.translate(instr[0])

        self._solver.add(smt_expr[0])

        # add constraints
        constraints = [
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) == 0xffffffff,
            self._solver.mkBitVec(64, self._translator.get_curr_name("t2")) != 0x000000000000ffff,
        ]

        self._solver.add(constraints[0])
        self._solver.add(constraints[1])

        self.assertEqual(self._solver.check(), 'unsat')

    def test_bsh_right_2(self):
        instr = self._parser.parse(["bsh [DWORD t1, DWORD -16, DWORD t2]"])

        self._solver.reset()

        smt_expr = self._translator.translate(instr[0])

        self._solver.add(smt_expr[0])

        # add constraints
        constraints = [
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) == 0xffffffff,
            self._solver.mkBitVec(32, self._translator.get_curr_name("t2")) != 0x0000ffff,
        ]

        self._solver.add(constraints[0])
        self._solver.add(constraints[1])

        self.assertEqual(self._solver.check(), 'unsat')

    def test_bsh_right_3(self):
        instr = self._parser.parse(["bsh [DWORD t1, DWORD -16, WORD t2]"])

        self._solver.reset()

        smt_expr = self._translator.translate(instr[0])

        self._solver.add(smt_expr[0])

        # add constraints
        constraints = [
            self._solver.mkBitVec(32, self._translator.get_curr_name("t1")) == 0xffffffff,
            self._solver.mkBitVec(16, self._translator.get_curr_name("t2")) != 0xffff,
        ]

        self._solver.add(constraints[0])
        self._solver.add(constraints[1])

        self.assertEqual(self._solver.check(), 'unsat')

    # Auxiliary methods
    # ======================================================================== #
    def _translate_instrs(self, instrs):
        """Translate each REIL instruction and add its formulae to the solver.

        Translation always happens; VERBOSE only controls the listing. If the
        formulae were added only in verbose mode, the solver would be queried
        without the instruction semantics and the sat check would be vacuous.
        """
        if VERBOSE:
            print("[+] Instructions:")

        for instr in instrs:
            trans = self._translator.translate(instr)

            if trans is not None:
                self._solver.add(trans)

            if VERBOSE:
                print("    %-20s -> %-20s" % (instr, trans))

    def _add_constraints(self, constraints):
        """Add every constraint to the solver; list them when VERBOSE."""
        if VERBOSE:
            print("[+] Constraints :")

        for i, constr in enumerate(constraints):
            self._solver.add(constr)

            if VERBOSE:
                print("    %2d : %s" % (i, constr))

    def _check_sat(self, reg_names, exprs=()):
        """Return whether the solver reports 'sat'; print a model when VERBOSE.

        :param reg_names: names whose current translator value is printed.
        :param exprs: ``(label, smt_expr)`` pairs printed via
            ``solver.getvalue``.
        """
        is_sat = self._solver.check() == 'sat'

        if VERBOSE:
            print("[+] Satisfiability : %s" % str(is_sat))

            if is_sat:
                for name in reg_names:
                    value = self._solver.getvaluebyname(
                        self._translator.get_curr_name(name))
                    print("    %s : 0x%08x" % (name, value))

                for label, expr in exprs:
                    print("    %s : 0x%08x" % (label,
                                               self._solver.getvalue(expr)))

            print("~" * 80)

        return is_sat
Exemplo n.º 5
0
class ReilEmulatorTests(unittest.TestCase):
    """Tests for ReilEmulator running x86 code translated to REIL."""

    def setUp(self):
        self._arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        self._emulator = ReilEmulator(self._arch_info)

        self._asm_parser = X86Parser()
        self._reil_parser = ReilParser()

        self._translator = X86Translator()

    def test_add(self):
        asm_instr = self._asm_parser.parse("add eax, ebx")

        self.__set_address(0xdeadbeef, [asm_instr])

        reil_instrs = self._translator.translate(asm_instr)

        regs_initial = {
            "eax": 0x1,
            "ebx": 0x2,
        }

        regs_final, _ = self._emulator.execute_lite(
            reil_instrs,
            context=regs_initial
        )

        self.assertEqual(regs_final["eax"], 0x3)
        self.assertEqual(regs_final["ebx"], 0x2)

    def test_loop(self):
        """Run a small counting loop through the container-based executor."""
        # 0x08048060 : b8 00 00 00 00   mov eax,0x0
        # 0x08048065 : bb 0a 00 00 00   mov ebx,0xa
        # 0x0804806a : 83 c0 01         add eax,0x1
        # 0x0804806d : 83 eb 01         sub ebx,0x1
        # 0x08048070 : 83 fb 00         cmp ebx,0x0
        # 0x08048073 : 75 f5            jne 0x0804806a

        asm_instrs_str = [(0x08048060, "mov eax,0x0", 5)]
        asm_instrs_str += [(0x08048065, "mov ebx,0xa", 5)]
        asm_instrs_str += [(0x0804806a, "add eax,0x1", 3)]
        asm_instrs_str += [(0x0804806d, "sub ebx,0x1", 3)]
        asm_instrs_str += [(0x08048070, "cmp ebx,0x0", 3)]
        asm_instrs_str += [(0x08048073, "jne 0x0804806a", 2)]

        asm_instrs = []

        for addr, asm, size in asm_instrs_str:
            asm_instr = self._asm_parser.parse(asm)
            asm_instr.address = addr
            asm_instr.size = size

            asm_instrs.append(asm_instr)

        reil_instrs = self.__translate(asm_instrs)

        # REIL addresses are the x86 address shifted left by 8 bits.
        regs_final, _ = self._emulator.execute(
            reil_instrs,
            start=0x08048060 << 8
        )

        self.assertEqual(regs_final["eax"], 0xa)
        self.assertEqual(regs_final["ebx"], 0x0)

    def test_mov(self):
        asm_instrs = [self._asm_parser.parse("mov eax, 0xdeadbeef")]
        asm_instrs += [self._asm_parser.parse("mov al, 0x12")]
        asm_instrs += [self._asm_parser.parse("mov ah, 0x34")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self._translator.translate(asm_instrs[0])
        reil_instrs += self._translator.translate(asm_instrs[1])
        reil_instrs += self._translator.translate(asm_instrs[2])

        regs_initial = {
            "eax": 0xffffffff,
        }

        regs_final, _ = self._emulator.execute_lite(
            reil_instrs,
            context=regs_initial
        )

        self.assertEqual(regs_final["eax"], 0xdead3412)

    def test_pre_hanlder(self):
        """The pre-handler runs and receives the parameter registered with it."""
        def pre_hanlder(emulator, instruction, parameter):
            # Record the call through the argument the emulator passes in
            # (not a closure), so this also checks the parameter is forwarded.
            parameter.append(True)

        asm = ["mov eax, ebx"]

        # Real lists (not lazy ``map`` objects): they are iterated by
        # __set_address and then indexed below.
        x86_instrs = [self._asm_parser.parse(line) for line in asm]
        self.__set_address(0xdeadbeef, x86_instrs)
        reil_instrs = [self._translator.translate(i) for i in x86_instrs]

        calls = []

        self._emulator.set_instruction_pre_handler(pre_hanlder, calls)

        self._emulator.execute_lite(reil_instrs[0])

        self.assertTrue(len(calls) > 0)

    def test_post_hanlder(self):
        """The post-handler runs and receives the parameter registered with it."""
        def post_hanlder(emulator, instruction, parameter):
            # Record the call through the argument the emulator passes in
            # (not a closure), so this also checks the parameter is forwarded.
            parameter.append(True)

        asm = ["mov eax, ebx"]

        # Real lists (not lazy ``map`` objects): they are iterated by
        # __set_address and then indexed below.
        x86_instrs = [self._asm_parser.parse(line) for line in asm]
        self.__set_address(0xdeadbeef, x86_instrs)
        reil_instrs = [self._translator.translate(i) for i in x86_instrs]

        calls = []

        self._emulator.set_instruction_post_handler(post_hanlder, calls)

        self._emulator.execute_lite(reil_instrs[0])

        self.assertTrue(len(calls) > 0)

    def test_zero_division_error_1(self):
        asm_instrs = [self._asm_parser.parse("div ebx")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self._translator.translate(asm_instrs[0])

        regs_initial = {
            "eax": 0x2,
            "edx": 0x2,
            "ebx": 0x0,
        }

        self.assertRaises(ReilCpuZeroDivisionError,
                          self._emulator.execute_lite,
                          reil_instrs,
                          context=regs_initial)

    def test_zero_division_error_2(self):
        instrs = ["mod [DWORD eax, DWORD ebx, DWORD t0]"]

        reil_instrs = self._reil_parser.parse(instrs)

        # Give the hand-written REIL instruction an address, as translated
        # instructions have (x86 address << 8).
        reil_instrs[0].address = 0xdeadbeef00

        regs_initial = {
            "eax": 0x2,
            "ebx": 0x0,
        }

        self.assertRaises(ReilCpuZeroDivisionError,
                          self._emulator.execute_lite,
                          reil_instrs,
                          context=regs_initial)

    def test_invalid_address_error_1(self):
        asm_instrs = [self._asm_parser.parse("jmp eax")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self.__translate(asm_instrs)

        regs_initial = {
            "eax": 0xffffffff,
        }

        self.assertRaises(ReilCpuInvalidAddressError,
                          self._emulator.execute,
                          reil_instrs,
                          start=0xdeadbeef << 8,
                          registers=regs_initial)

    def test_invalid_address_error_2(self):
        asm_instrs = [self._asm_parser.parse("mov eax, 0xdeadbeef")]

        self.__set_address(0xdeadbeef, asm_instrs)

        reil_instrs = self.__translate(asm_instrs)

        regs_initial = {
            "eax": 0xffffffff,
        }

        # Start one byte past the only instruction.
        self.assertRaises(ReilCpuInvalidAddressError,
                          self._emulator.execute,
                          reil_instrs,
                          start=0xdeadbef0 << 8,
                          registers=regs_initial)

    # Auxiliary methods
    # ======================================================================== #
    def __set_address(self, address, asm_instrs):
        """Give *asm_instrs* consecutive addresses starting at *address*."""
        addr = address

        for asm_instr in asm_instrs:
            asm_instr.address = addr
            addr += 1

    def __translate(self, asm_instrs):
        """Translate *asm_instrs* into a ReilContainer of linked sequences.

        Each x86 instruction becomes one ReilSequence, and each sequence's
        ``next_sequence_address`` is set to the address of the sequence after
        it. The last sequence's ``next_sequence_address`` is not assigned
        here.
        """
        instr_container = ReilContainer()

        instr_seq_prev = None

        for asm_instr in asm_instrs:
            instr_seq = ReilSequence()

            for reil_instr in self._translator.translate(asm_instr):
                instr_seq.append(reil_instr)

            if instr_seq_prev:
                instr_seq_prev.next_sequence_address = instr_seq.address

            instr_container.add(instr_seq)

            instr_seq_prev = instr_seq

        # The last sequence is deliberately not linked to a fall-through
        # address (address + size): the invalid-address tests set only
        # ``address`` on their instructions, not ``size``.
        return instr_container
Exemplo n.º 6
0
class SmtTranslatorTests(unittest.TestCase):
    """Unit tests for the SMT translation of individual REIL instructions.

    Each test parses one textual REIL instruction, translates it with
    ``SmtTranslator`` and compares the resulting formula text (or, for
    unsupported instructions, the raised error message).
    """

    def setUp(self):
        self._address_size = 32
        self._parser = ReilParser()
        self._solver = SmtSolver()
        self._translator = SmtTranslator(self._solver, self._address_size)

        self._arch_info = X86ArchitectureInformation(ARCH_X86_MODE_32)

        self._translator.set_arch_alias_mapper(self._arch_info.alias_mapper)
        self._translator.set_arch_registers_size(
            self._arch_info.registers_size)

    # Auxiliary methods
    def _parse_one(self, text):
        """Parse a single REIL instruction string and return the instruction."""
        return self._parser.parse([text])[0]

    def _assert_single_formula(self, text, expected):
        """Translate *text* and check it yields exactly one formula, *expected*."""
        form = self._translator.translate(self._parse_one(text))

        self.assertEqual(len(form), 1)
        self.assertEqual(form[0].value, expected)

    def _assert_unsupported(self, text, message):
        """Check that translating *text* raises an exception mentioning *message*."""
        instr = self._parse_one(text)

        with self.assertRaises(Exception) as context:
            self._translator.translate(instr)

        # Search the message text, not the exception object: `in` on an
        # exception raises TypeError on Python 3 and, on Python 2, only
        # compares against whole entries of ``exception.args``.
        self.assertIn(message, str(context.exception))

    # Arithmetic Instructions
    def test_translate_add_1(self):
        # Same size operands.
        self._assert_single_formula(
            "add [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvadd t0_0 t1_0))")

    def test_translate_add_2(self):
        # Destination operand larger than source operands.
        self._assert_single_formula(
            "add [BYTE t0, BYTE t1, WORD t2]",
            "(= t2_1 (bvadd ((_ zero_extend 8) t0_0) ((_ zero_extend 8) t1_0)))")

    def test_translate_add_3(self):
        # Destination operand smaller than source operands.
        self._assert_single_formula(
            "add [WORD t0, WORD t1, BYTE t2]",
            "(= t2_1 ((_ extract 7 0) (bvadd t0_0 t1_0)))")

    def test_translate_add_4(self):
        # Mixed source operands.
        self._assert_single_formula(
            "add [BYTE t0, BYTE 0x12, WORD t2]",
            "(= t2_1 (bvadd ((_ zero_extend 8) t0_0) ((_ zero_extend 8) #x12)))")

    def test_translate_sub(self):
        self._assert_single_formula(
            "sub [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvsub t0_0 t1_0))")

    def test_translate_mul(self):
        self._assert_single_formula(
            "mul [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvmul t0_0 t1_0))")

    def test_translate_div(self):
        self._assert_single_formula(
            "div [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvudiv t0_0 t1_0))")

    def test_translate_mod(self):
        self._assert_single_formula(
            "mod [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvurem t0_0 t1_0))")

    def test_translate_bsh(self):
        self._assert_single_formula(
            "bsh [DWORD t0, DWORD t1, DWORD t2]",
            "(= t2_1 (ite (bvsge t1_0 #x00000000) (bvshl t0_0 t1_0) (bvlshr t0_0 (bvneg t1_0))))")

    # Bitwise Instructions
    def test_translate_and(self):
        self._assert_single_formula(
            "and [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvand t0_0 t1_0))")

    def test_translate_or(self):
        self._assert_single_formula(
            "or [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvor t0_0 t1_0))")

    def test_translate_xor(self):
        self._assert_single_formula(
            "xor [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvxor t0_0 t1_0))")

    # Data Transfer Instructions
    def test_translate_ldm(self):
        self._assert_single_formula(
            "ldm [DWORD t0, empty, BYTE t2]",
            "(= (select MEM_0 (bvadd t0_0 #x00000000)) t2_1)")

    def test_translate_stm(self):
        self._assert_single_formula(
            "stm [BYTE t0, empty, DWORD t2]",
            "(= MEM_1 (store MEM_0 (bvadd t2_0 #x00000000) t0_0))")

    def test_translate_str(self):
        self._assert_single_formula(
            "str [BYTE t0, empty, BYTE t2]",
            "(= t2_1 t0_0)")

    # Conditional Instructions
    def test_translate_bisz(self):
        self._assert_single_formula(
            "bisz [DWORD t0, empty, DWORD t2]",
            "(= t2_1 (ite (= t0_0 #x00000000) #x00000001 #x00000000))")

    def test_translate_jcc(self):
        self._assert_single_formula(
            "jcc [BIT t0, empty, DWORD t2]",
            "(not (= t0_0 #b0))")

    # Other Instructions
    def test_translate_undef(self):
        self._assert_unsupported(
            "undef [empty, empty, DWORD t2]",
            "Unsupported instruction : UNDEF")

    def test_translate_unkn(self):
        self._assert_unsupported(
            "unkn [empty, empty, empty]",
            "Unsupported instruction : UNKN")

    def test_translate_nop(self):
        # A nop produces no formulae at all.
        form = self._translator.translate(
            self._parse_one("nop [empty, empty, empty]"))

        self.assertEqual(len(form), 0)

    # Extensions
    def test_translate_sext(self):
        self._assert_single_formula(
            "sext [BYTE t0, empty, WORD t2]",
            "(= t2_1 ((_ sign_extend 8) t0_0))")

    def test_translate_sdiv(self):
        self._assert_single_formula(
            "sdiv [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvsdiv t0_0 t1_0))")

    def test_translate_smod(self):
        self._assert_single_formula(
            "smod [BYTE t0, BYTE t1, BYTE t2]",
            "(= t2_1 (bvsmod t0_0 t1_0))")
# Example No. 7
# 0
class SmtTranslatorTests(unittest.TestCase):
    def setUp(self):
        # Build a fresh parser, solver and translator for every test; the
        # translator works over a 32-bit address space.
        address_size = 32
        parser = ReilParser()
        solver = SmtSolver()

        self._address_size = address_size
        self._parser = parser
        self._solver = solver
        self._translator = SmtTranslator(solver, address_size)

    def test_add_reg_reg(self):
        """add eax, ebx: the final eax can be 42 while the initial eax is not."""
        if VERBOSE:
            print("\n[+] Test: test_add_reg_reg")

        # add eax, ebx
        instrs = self._parser.parse([
            "add [eax, ebx, t0]",
            "str [t0, e, eax]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[1].size = 32
        instrs[0].operands[2].size = 32

        instrs[1].operands[0].size = 32
        instrs[1].operands[2].size = 32

        self._solver.reset()

        # Translate instructions to formulae. This must run unconditionally;
        # only the printing depends on VERBOSE, otherwise the solver would be
        # checked without any formulae and the test would assert nothing.
        if VERBOSE:
            print("[+] Instructions:")

        for instr in instrs:
            trans = self._translator.translate(instr)

            if trans is not None:
                self._solver.add(trans)

            if VERBOSE:
                print("    %-20s -> %-20s" % (instr, trans))

        # Add constraints (built after translation so get_curr_name reflects
        # the translated instructions).
        constraints = [
            self._solver.mkBitVec(32,
                                  self._translator.get_curr_name("eax")) == 42,
            self._solver.mkBitVec(32, self._translator.get_init_name("eax")) !=
            42,
        ]

        if VERBOSE:
            print("[+] Constraints :")

        for i, constr in enumerate(constraints):
            self._solver.add(constr)

            if VERBOSE:
                print("    %2d : %s" % (i, constr))

        # Check satisfiability.
        is_sat = self._solver.check() == 'sat'

        if VERBOSE:
            print("[+] Satisfiability : %s" % str(is_sat))

            if is_sat:
                print("    eax : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("eax")))
                print("    ebx : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("ebx")))

            print("~" * 80)

        self.assertEqual(is_sat, True)

    def test_add_reg_mem(self):
        """add eax, [ebx]: the final eax can be 42 while the initial eax is not."""
        if VERBOSE:
            print("\n[+] Test: test_add_reg_mem")

        # add eax, [ebx]
        instrs = self._parser.parse([
            "ldm [ebx, EMPTY, t0]",
            "add [eax, t0, t1]",
            "str [t1, EMPTY, eax]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[2].size = 32

        instrs[1].operands[0].size = 32
        instrs[1].operands[1].size = 32
        instrs[1].operands[2].size = 32

        instrs[2].operands[0].size = 32
        instrs[2].operands[2].size = 32

        self._solver.reset()

        # Translate instructions to formulae. This must run unconditionally;
        # only the printing depends on VERBOSE, otherwise the solver would be
        # checked without any formulae and the test would assert nothing.
        if VERBOSE:
            print("[+] Instructions:")

        for instr in instrs:
            trans = self._translator.translate(instr)

            if trans is not None:
                self._solver.add(trans)

            if VERBOSE:
                print("    %-20s -> %-20s" % (instr, trans))

        # Add constraints (built after translation so get_curr_name reflects
        # the translated instructions).
        constraints = [
            self._solver.mkBitVec(32,
                                  self._translator.get_curr_name("eax")) == 42,
            self._solver.mkBitVec(32, self._translator.get_init_name("eax")) !=
            42,
        ]

        if VERBOSE:
            print("[+] Constraints :")

        for i, constr in enumerate(constraints):
            self._solver.add(constr)

            if VERBOSE:
                print("    %2d : %s" % (i, constr))

        # Check satisfiability.
        is_sat = self._solver.check() == 'sat'

        if VERBOSE:
            print("[+] Satisfiability : %s" % str(is_sat))

            if is_sat:
                print("    eax : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("eax")))
                print("    ebx : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("ebx")))

            print("~" * 80)

        self.assertEqual(is_sat, True)

    def test_add_mem_reg(self):
        """add [eax], ebx: [eax] can go from not-42 to 42 via the store."""
        if VERBOSE:
            print("\n[+] Test: test_add_mem_reg")

        # add [eax], ebx
        instrs = self._parser.parse([
            "ldm [eax, EMPTY, t0]",
            "add [t0, ebx, t1]",
            "stm [t1, EMPTY, eax]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[2].size = 32

        instrs[1].operands[0].size = 32
        instrs[1].operands[1].size = 32
        instrs[1].operands[2].size = 32

        instrs[2].operands[0].size = 32
        instrs[2].operands[2].size = 32

        self._solver.reset()

        # Pre-translation constraint: memory at eax is not 42.
        # NOTE(review): this test is only satisfiable if `mem[eax]` built
        # after translation refers to the post-store memory, i.e. the object
        # returned by get_memory() tracks later stores; confirm in
        # SmtTranslator.
        mem = self._translator.get_memory()
        eax = self._solver.mkBitVec(32, "eax_0")

        constraint = (mem[eax] != 42)

        if VERBOSE:
            print("constraint : %s" % constraint)

        self._solver.add(constraint)

        # Translate instructions to formulae. This must run unconditionally;
        # only the printing depends on VERBOSE, otherwise the solver would be
        # checked without the translated formulae.
        if VERBOSE:
            print("[+] Instructions:")

        for instr in instrs:
            trans = self._translator.translate(instr)

            if trans is not None:
                self._solver.add(trans)

            if VERBOSE:
                print("    %-20s -> %-20s" % (instr, trans))

        # Post-translation constraints.
        constraints = [
            mem[eax] == 42,
            self._solver.mkBitVec(32, self._translator.get_init_name("t0")) !=
            42,
        ]

        if VERBOSE:
            print("[+] Constraints :")

        for i, constr in enumerate(constraints):
            self._solver.add(constr)

            if VERBOSE:
                print("    %2d : %s" % (i, constr))

        # Check satisfiability.
        is_sat = self._solver.check() == 'sat'

        if VERBOSE:
            print("[+] Satisfiability : %s" % str(is_sat))

            if is_sat:
                print("    eax : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("eax")))
                print("    ebx : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("ebx")))
                print("    t0 : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("t0")))
                print("    [eax] : 0x%08x" % self._solver.getvalue(mem[eax]))

            print("~" * 80)

        self.assertEqual(is_sat, True)

    def test_add_mem_reg_2(self):
        """add [eax + 0x1000], ebx: [eax + 0x1000] can go from not-42 to 42."""
        if VERBOSE:
            print("\n[+] Test: test_add_mem_reg_2")

        # add [eax + 0x1000], ebx
        instrs = self._parser.parse([
            "add [eax, 0x1000, t0]",
            "ldm [t0, e, t1]",
            "add [t1, ebx, t2]",
            "stm [t2, e, t0]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[1].size = 32
        instrs[0].operands[2].size = 32

        instrs[1].operands[0].size = 32
        instrs[1].operands[2].size = 32

        instrs[2].operands[0].size = 32
        instrs[2].operands[1].size = 32
        instrs[2].operands[2].size = 32

        instrs[3].operands[0].size = 32
        instrs[3].operands[2].size = 32

        self._solver.reset()

        # Pre-translation constraint: memory at eax + 0x1000 is not 42.
        # NOTE(review): satisfiability relies on `mem[...]` built after
        # translation referring to the post-store memory; confirm in
        # SmtTranslator.
        mem = self._translator.get_memory()
        eax = self._solver.mkBitVec(32, "eax_0")
        off = BitVec(32, "#x%08x" % 0x1000)

        constraint = (mem[eax + off] != 42)

        if VERBOSE:
            print("constraint : %s" % constraint)

        self._solver.add(constraint)

        # Translate instructions to formulae. This must run unconditionally;
        # only the printing depends on VERBOSE, otherwise the solver would be
        # checked without the translated formulae.
        if VERBOSE:
            print("[+] Instructions:")

        for instr in instrs:
            trans = self._translator.translate(instr)

            if trans is not None:
                self._solver.add(trans)

            if VERBOSE:
                print("    %-20s -> %-20s" % (instr, trans))

        # Post-translation constraints.
        constraints = [
            mem[eax + off] == 42,
            self._solver.mkBitVec(32, self._translator.get_init_name("t1")) !=
            42,
        ]

        if VERBOSE:
            print("[+] Constraints :")

        for i, constr in enumerate(constraints):
            self._solver.add(constr)

            if VERBOSE:
                print("    %2d : %s" % (i, constr))

        # Check satisfiability.
        is_sat = self._solver.check() == 'sat'

        if VERBOSE:
            print("[+] Satisfiability : %s" % str(is_sat))

            if is_sat:
                print("    eax : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("eax")))
                print("    ebx : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("ebx")))
                print("    t0 : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("t0")))
                print("    t1 : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("t1")))
                print("    [eax + off] : 0x%08x" % self._solver.getvalue(
                    mem[eax + off]))

            print("~" * 80)

        self.assertEqual(is_sat, True)

    def test_mul(self):
        """mul 0x0, 0x1 -> t0: the final t0 can be 0 while the initial t0 is not."""
        if VERBOSE:
            print("\n[+] Test: test_mul")

        instrs = self._parser.parse([
            "mul [0x0, 0x1, t0]",
        ])

        instrs[0].operands[0].size = 32
        instrs[0].operands[1].size = 32

        # TODO: Review this; the output size should be 64.
        instrs[0].operands[2].size = 32

        self._solver.reset()

        # Translate instructions to formulae. This must run unconditionally;
        # only the printing depends on VERBOSE, otherwise the solver would be
        # checked without any formulae and the test would assert nothing.
        if VERBOSE:
            print("[+] Instructions:")

        for instr in instrs:
            trans = self._translator.translate(instr)

            if trans is not None:
                self._solver.add(trans)

            if VERBOSE:
                print("    %-20s -> %-20s" % (instr, trans))

        # Add constraints (built after translation so get_curr_name reflects
        # the translated instruction).
        constraints = [
            self._solver.mkBitVec(32,
                                  self._translator.get_curr_name("t0")) == 0,
            self._solver.mkBitVec(32, self._translator.get_init_name("t0")) !=
            0,
        ]

        if VERBOSE:
            print("[+] Constraints :")

        for i, constr in enumerate(constraints):
            self._solver.add(constr)

            if VERBOSE:
                print("    %2d : %s" % (i, constr))

        # Check satisfiability.
        is_sat = self._solver.check() == 'sat'

        if VERBOSE:
            print("[+] Satisfiability : %s" % str(is_sat))

            if is_sat:
                print("    t0 : 0x%08x" % self._solver.getvaluebyname(
                    self._translator.get_curr_name("t0")))

            print("~" * 80)

        self.assertEqual(is_sat, True)