Exemple #1
0
    def __init__(self, emu, memory, stack_size):
        """
        Stores the collaborators of the thread manager and resets its
        bookkeeping.

        Args:
            emu: Emulator handle; also handed to the Scheduler.
            memory: Memory manager kept on the instance as ``memory``.
            stack_size: Stack size kept on the instance as ``stack_size``.
        """
        # Address window for stacks; the upper bound is the
        # MAX MIPS 32-bit handles w/ qemu
        self.stack_min, self.stack_max = 0x00000000, 0xC0000000

        self.emu = emu
        # The scheduler receives this (still partially built) instance,
        # so the assignment order in this method is kept as-is.
        self.scheduler = Scheduler(self, emu)
        self.memory, self.stack_size = memory, stack_size
        self.logger = logging.getLogger(__name__)
        self._reset()
Exemple #2
0
class Threads:
    """
    Handles the execution of multiple threads on the same memory.

    This class should be manipulated through the process layer.

    Every known thread lives in ``thread_list``. At most one thread is
    loaded into the emulator at a time; it is referenced by
    ``current_thread`` (``None`` when no thread is loaded). Swapping
    threads saves the outgoing thread's context and restores the
    incoming thread's context into the emulator.
    """
    def __init__(self, emu, memory, stack_size, hook_manager):
        """
        Args:
            emu: Emulator whose context is saved/restored on swaps.
            memory: Memory manager used to map a stack per new thread.
            stack_size: Size passed to the memory manager when mapping
                each thread's stack.
            hook_manager: Source of the THREAD.SWAP hooks run by _swap.
        """
        self.stack_min = 0x00000000
        self.stack_max = 0xC0000000  # MAX MIPS 32-bit handles w/ qemu

        self.emu = emu
        # The scheduler is handed this partially-built instance, so the
        # assignment order in this method is deliberately left alone.
        self.scheduler = Scheduler(self, emu)
        self.memory = memory
        self.stack_size = stack_size
        self._hook_manager = hook_manager
        self.logger = logging.getLogger(__name__)
        self._reset()

    def _reset(self):
        """Drops every thread and restores the initial bookkeeping."""
        self.verbose = False
        self.thread_list = []
        self.current_thread = None
        self.dll_funcs = defaultdict(int)

        # A count of threads that are created, so we can create unique
        # names for threads
        self.thread_count = 0

    def __str__(self):
        """
        Lists all threads sorted by name, one per line, marking the
        current thread with a leading ``*``.
        """
        header = "Thread Manager's Threads:\n"
        threads = self.get_all_threads()
        if not threads:
            return header + "  No threads present :(\n"

        lines = [
            f"  *{t}\n" if self.is_current_thread(t) else f"   {t}\n"
            for t in sorted(threads, key=lambda x: x.name)
        ]
        return header + "".join(lines)

    @property
    def completed_threads(self):
        """All threads whose state is ThreadState.SUCCESS."""
        return [
            t for t in self.get_all_threads() if t.state == ThreadState.SUCCESS
        ]

    @property
    def failed_threads(self):
        """All threads whose state is ThreadState.FAILURE."""
        return [
            t for t in self.get_all_threads() if t.state == ThreadState.FAILURE
        ]

    def get_active_threads(self) -> List[Thread]:
        """
        Returns all active threads.

        Paused threads are re-evaluated first, so a thread whose pause
        condition has become true is included in the result.
        """
        self._check_paused_threads()
        return [t for t in self.get_all_threads() if t.is_active]

    def is_current_thread(self, t: Thread) -> bool:
        """
        Returns True if "t" is the currently running thread.

        Args:
            t: The thread to check.

        Returns:
            True if "t" is currently running. Threads are compared by
            id, not by object identity.
        """
        return (self.current_thread is not None
                and self.current_thread.id == t.id)

    def kill_thread(self, tid: int) -> None:
        """
        Stops the specified thread if it is in the RUNNING state.

        A thread that is not the current thread has its state set to
        KILLED. The current thread is instead inactivated with state
        FAILURE (which also unloads it and stops the scheduler).
        Unknown tids and threads that are not RUNNING are only logged.

        Args:
            tid: The thread id of desired thread to kill
        """
        t = self.get_thread(tid)
        if t is None:
            self.logger.notice(f"No thread {tid:x} to kill")
            return
        if t.state != ThreadState.RUNNING:
            self.logger.info(f"Thread {tid:x}, is already in state {t.state}. "
                             f"Refusing to kill")
            return
        self.logger.info(f"Killing {tid:x}")
        if self.is_current_thread(t):
            self._inactivate_with_state(ThreadState.FAILURE)
        else:
            t.state = ThreadState.KILLED

    def as_current_thread(self, t: Thread, closure: Callable[[], Any]) -> Any:
        """
        Executes the closure as if "t" was the current thread.

        The previously current thread is swapped back in afterwards,
        even when the closure raises. Swap hooks are not invoked.

        NOTE(review): when no thread is current, the closure is run
        without loading "t" at all; this mirrors the long-standing
        behavior and should be confirmed against callers.

        Args:
            t: The thread to set active while executing the closure
            closure: The function to execute

        Returns:
            The result of the closure.
        """
        if self.is_current_thread(t):
            return closure()

        previous_tid = None
        if self.current_thread is not None:
            previous_tid = self.current_thread.id
            self._swap_thread(t.id)

        try:
            return closure()
        finally:
            # Always hand control back, otherwise an exception in the
            # closure would leave "t" loaded as the current thread.
            if previous_tid is not None:
                self._swap_thread(previous_tid)

    # TODO(V): Remove this function, put in timeleap
    def record_block(self, block_address):
        """
        Increments the execution count of block_address for the current
        thread. Does nothing when no thread is current.
        """
        if self.current_thread is not None:
            self.current_thread.blocks_executed[block_address] += 1

    def block_seen_before(self, block_address):
        """
        Returns true if the block address has been seen in the current
        thread before. Returns None when there is no current thread.
        """
        if self.current_thread is None:
            return None
        return block_address in self.current_thread.blocks_executed

    def num_unique_blocks(self, thread_name=None):
        """
        Returns the number of unique blocks for the given thread (0 if
        no thread has that name).

        If no thread name is given, returns the sum of the per-thread
        unique block counts over the currently active threads; a block
        executed by two threads is therefore counted twice.
        """
        if thread_name is not None:
            t = self.get_thread_by_name(thread_name)
            return len(t.blocks_executed) if t is not None else 0

        return sum(
            len(t.blocks_executed) for t in self.get_active_threads())

    def executed_within_region(self, begin_addr, end_addr, thread_names=None):
        """
        Returns all block starts within the specified region, executed
        by the specified threads. If no thread_names are specified,
        checks all threads.

        The region is half-open: begin_addr <= addr < end_addr. An
        address executed by several threads appears once per thread.
        """
        return [
            addr
            for t in self.get_threads(thread_names)
            for addr in t.blocks_executed
            if begin_addr <= addr < end_addr
        ]

    def num_active_threads(self) -> int:
        """
        Returns the number of threads that are still executing.

        Returns:
            Number of threads that are still executing
        """
        return len(self.get_active_threads())

    # Ways a thread can fail
    #   * Inside an API, we can just say what the api is
    #   *   If it was a sys call, it may be useful to indicate this
    #   * Outside an api, this has to be an exception

    def fail_current_thread(self, fail_reason: Optional[str] = None) -> None:
        """
        Records the current thread as a failure and removes it from
        execution

        Args:
            fail_reason: Keeps track of why the thread failed. Used in
                debugging. Stored as "Unknown" when omitted.
        """
        self.current_thread.fail_reason = ("Unknown" if fail_reason is None
                                           else fail_reason)
        self.logger.error(
            "Thread %s failed: %s",
            self.current_thread.name,
            self.current_thread.fail_reason,
        )
        self._inactivate_with_state(ThreadState.FAILURE)

    def complete_current_thread(self) -> None:
        """
        Records the current thread as having completed successfully and
        removes it from execution
        """
        self.logger.success(
            f"Done executing thread {self.current_thread.name}")
        self._inactivate_with_state(ThreadState.SUCCESS)

    def pause_current_thread(self,
                             condition: Optional[Callable[[], bool]] = None
                             ) -> None:
        """
        Pauses the thread until the condition closure is checked and it
        evaluates to true. If no condition is supplied, the thread is
        paused indefinitely.

        If the condition is already true when this is called, the
        thread is not paused and a notice is logged instead.

        Args:
            condition: Evaluated periodically, if it ever returns True,
                unpauses the thread

        """
        if condition is None:

            # A condition that never holds keeps the thread paused.
            def condition():
                return False

        if condition():
            self.logger.notice("Pause condition is already true. "
                               "This is probably unintended.")
            return

        self.current_thread.pause_condition = condition

        self.logger.info(f"Pausing thread {self.current_thread.name}")
        self._inactivate_with_state(ThreadState.PAUSED)

    def _inactivate_with_state(self, thread_state):
        """
        Sets the current thread's state, unloads it (running the swap
        hooks) and asks the scheduler to stop.
        """
        self.current_thread.state = thread_state
        self._swap(None)
        # NOTE(review): 0x30 is presumably a placeholder address used
        # while no thread is loaded -- confirm.
        self.emu.setIP(0x30)
        self.scheduler.stop_and_exec("inactivate thread", lambda: True)

    def _check_paused_threads(self):
        """
        Checks whether any paused threads are ready to run.

        A paused thread whose condition returns True is switched back
        to RUNNING. A paused thread without a condition stays paused.
        """
        for t in self.get_all_threads():
            if t.state != ThreadState.PAUSED:
                continue
            condition = t.pause_condition
            if condition is None:
                # Nothing to evaluate; calling None would raise.
                continue
            if condition():
                self.logger.info(f"Thread {t.name} has been unpaused!")
                t.pause_condition = None
                t.state = ThreadState.RUNNING

    def new_thread(
        self,
        start_addr: int,
        tid: int,
        name: Optional[str] = None,
        priority: int = 0,
        stack_setup=None,
        module_path: str = "????",
        benign_code: bool = False,
    ) -> Thread:
        """
        Maps a stack for a new thread, creates the thread and registers
        it in the thread list.

        Args:
            start_addr: Address the thread starts executing at.
            tid: Thread id given to the new thread.
            name: Thread name. When omitted, a unique
                "child_thread_<n>" name is generated.
            priority: Priority passed to the Thread.
            stack_setup: Forwarded to create_thread, which currently
                does not use it.
            module_path: Passed through to the Thread.
            benign_code: Passed through to the Thread.

        Returns:
            The newly created thread.
        """
        if name is None:
            name = f"child_thread_{self.thread_count}"
            self.thread_count += 1

        # We want to ensure that we initialize the stack for this thread
        stack_bottom = self.memory.map_anywhere(
            self.stack_size,
            min_addr=self.stack_min,
            max_addr=self.stack_max,
            name=name,
            kind="stack",
            prot=UC_PROT_READ | UC_PROT_WRITE,
        )

        # presumably the highest 0x1000-aligned address inside the
        # mapped stack region -- depends on util.align_down semantics
        stack_base = util.align_down(stack_bottom + self.stack_size - 1,
                                     alignment=0x1000)

        new_thread = self.create_thread(
            start_addr,
            tid,
            stack_base,
            name,
            priority,
            stack_setup,
            module_path,
            benign_code,
        )

        self.thread_list.append(new_thread)

        return new_thread

    def create_thread(
        self,
        start_addr,
        tid,
        stack_base,
        name=None,
        priority=0,
        stack_setup=None,
        module_path="????",
        benign_code=False,
    ) -> Thread:
        """
        Builds a Thread whose saved context starts at start_addr,
        without registering it in the thread list.

        The emulator's context is borrowed to capture the new thread's
        initial context and is restored before returning, also when
        constructing the Thread fails.

        Args:
            start_addr: Address the thread starts executing at.
            tid: Thread id given to the new thread.
            stack_base: Stack base address handed to the Thread.
            name: Thread name.
            priority: Priority passed to the Thread.
            stack_setup: Unused by this method; kept for interface
                compatibility.
            module_path: Passed through to the Thread.
            benign_code: Passed through to the Thread.

        Returns:
            The newly created thread; its parent is the current thread.
        """
        saved_context = self.emu.context_save()
        try:
            self.emu.setIP(start_addr)
            new_thread_context = self.emu.context_save()

            self.logger.debug(f"  Adding thread {name} (priority {priority}) "
                              f"stack base at {stack_base:x}")
            new_thread = Thread(
                self,
                new_thread_context,
                stack_base,
                self.stack_size,
                tid,
                name,
                priority,
                parent=self.current_thread,
                module_path=module_path,
                benign_code=benign_code,
            )

            # TODO: Not a fan of the fact that we set the current thread
            # to be the new thread for the hooks. Need to find a way to
            # allow for the thread_create hooks to run, without having
            # to continually switch between the active thread and this
            # one.
        finally:
            # Never leave the emulator pointing at the new thread's
            # start address.
            self.emu.context_restore(saved_context)
        return new_thread

    def change_thread_priority(self, thread_name, new_priority):
        """
        Change the priority of the first thread named thread_name.

        Logs a warning and changes nothing if no such thread exists.
        """
        t = self.get_thread_by_name(thread_name)
        if t is None:
            self.logger.warning("Unable to find thread %s", thread_name)
            return
        t.priority = new_priority

    def get_all_threads(self) -> List[Thread]:
        """ Returns a copy of the list of all threads, active or not"""
        return self.thread_list[:]

    def get_thread_by_name(self, name: str) -> Optional[Thread]:
        """ Returns the first thread with the given name, or None """
        threads = self.get_threads([name])
        return threads[0] if threads else None

    def get_thread(self, tid):
        """Returns the first thread whose id is tid, or None."""
        return next((t for t in self.get_all_threads() if t.id == tid), None)

    def get_threads(self, names):
        """
        Returns threads that have a name within the given list. If names
        is None, returns all threads
        """
        if names is None:
            return self.get_all_threads()
        return [t for t in self.get_all_threads() if t.name in names]

    def get_child_threads(self, tid: int) -> List[Thread]:
        """Returns all threads whose parent_id equals the given tid"""
        return [t for t in self.get_all_threads() if t.parent_id == tid]

    def swap_with_thread(self,
                         name: Optional[str] = None,
                         tid: Optional[int] = None) -> None:
        """
        Swaps the current thread with the first thread with the given
        name or thread id. Keep in mind, this will override the priority
        given to threads.
        You can only specify one of name or tid.

        Args:
            name: If specified, finds a thread by the name
            tid: If specified, finds a thread with that thread id

        Raises:
            ThreadException: If neither or both of name/tid are given,
                or no thread has the given name.
            InvalidTidException: If no thread has the given tid.
        """
        if name is None and tid is None:
            raise ThreadException("Must specify at least one of name/tid")
        if name is not None and tid is not None:
            raise ThreadException("May only specify one of name/tid")

        if name is not None:
            t = self.get_thread_by_name(name)
            if t is None:
                raise ThreadException(f"No thread named {name} exists.")
        else:
            t = self.get_thread(tid)
            if t is None:
                raise InvalidTidException(tid)

        self._swap(t)

    def swap_with_next_thread(self) -> None:
        """
        Swaps the current thread with the next thread to execute, as
        chosen by _next().

        When no thread is active this is logged and the current thread
        (if any) is unloaded, leaving no thread current.
        """
        self._check_paused_threads()
        t = self._next()
        if t is None:
            self.logger.spam("Can't swap with thread, no other threads")

        self._swap(t)

    def _swap_thread(self, tid=None) -> None:
        """
        Internal function that swaps the current thread with the thread
        of the specified tid or the next available thread, without invoking
        thread swap hooks. This function is intended to be used by functions
        that need to temporarily swap threads internally.
        """
        if tid is None:
            self._check_paused_threads()
            t = self._next()
        else:
            t = self.get_thread(tid)
        if self.current_thread is not None:
            self.current_thread.save_context()
        self._load(t)

    def _swap(self, thread):
        """
        Swaps the currently executing thread with the specified thread
        in the emulator, then runs every THREAD.SWAP hook with the
        previously current thread (possibly None) as its argument.
        """
        old_thread = self.current_thread
        if old_thread is not None:
            old_thread.save_context()
        self._load(thread)
        for hook in self._hook_manager._get_hooks(HookType.THREAD.SWAP):
            hook(old_thread)

    def _load(self, thread):
        """
        Loads the specified thread into the emulator.

        Passing None simply clears the current thread. Loading an
        inactive thread is logged as an error but still performed.
        """
        if thread is None:
            self.current_thread = None
            return
        if not thread.is_active:
            self.logger.error(
                f"Loading a thread with inactive state {thread.state}")
        self.emu.context_restore(thread.context)
        self.current_thread = thread
        self.logger.verbose(
            f"Loaded thread {thread.name}, "
            f"starting at {self.emu.getIP():x}, "
            f"stack at {self.emu.getSP():x}")
        # NOTE(review): re-setting the IP to its own value presumably
        # forces the emulator to resync after the restore -- confirm.
        self.emu.setIP(self.emu.getIP())

    def _next(self, tid=None):
        """
        Returns the next thread to be scheduled, or None if no thread
        is active.

        The first active thread in sorted order is chosen (ordering is
        defined by Thread itself) and then moved to the back of the
        thread list. The tid argument is unused.
        """
        active_threads = sorted(self.get_active_threads())
        if not active_threads:
            return None
        next_thread = active_threads[0]
        self._send_to_back(next_thread.id)
        return next_thread

    def _send_to_back(self, tid):
        """
        Sends this tid back to the end of the list. Used for scheduling.

        Raises:
            InvalidTidException: If no thread has the given tid.
        """
        for i, t in enumerate(self.thread_list):
            if t.id == tid:
                self.thread_list.append(self.thread_list.pop(i))
                return
        raise InvalidTidException(tid)

    # Registers captured per thread by _save_state and written back by
    # _load_state. NOTE(review): these are x86 register names, so the
    # save/load path appears to be x86-only -- confirm.
    regs_to_save = (
        "eax",
        "ebp",
        "ebx",
        "ecx",
        "edi",
        "edx",
        "flags",
        "eip",
        "esi",
        "esp",
    )

    def _save_state(self):
        """
        Returns a dict snapshot of all threads.

        Each thread is stored as its attribute dict (minus the emulator
        context) plus the values of ``regs_to_save``, which are read by
        temporarily restoring that thread's context into the emulator.
        The current thread's context is restored into the emulator
        before returning, also when serialization fails.
        """
        def _serialize_thread(thread):
            if thread is None:
                return None
            d = thread.__dict__.copy()
            del d["context"]  # Can't pickle, must be removed from dict
            self.emu.context_restore(thread.context)
            return (d, [self.emu.get_reg(reg) for reg in self.regs_to_save])

        # Capture the live register state before the emulator is reused
        # to read the other threads' registers.
        if self.current_thread is not None:
            self.current_thread.save_context()

        try:
            context = {
                "current_thread_tid":
                self.current_thread.id
                if self.current_thread is not None else None,
                "thread_list":
                [_serialize_thread(t) for t in self.thread_list],
                "thread_count":
                self.thread_count,
            }
        finally:
            # Restore the current thread's context
            if self.current_thread is not None:
                self.emu.context_restore(self.current_thread.context)
        return context

    def _load_state(self, data):
        """
        Rebuilds all threads from a snapshot produced by _save_state and
        loads the thread that was current when the snapshot was taken.
        """
        self._reset()

        def _deserialize_thread(data):
            if data is None:
                return None
            (thread_dict, reg_vals) = data
            # Unsure if you need deepcopy, but you definitely need to
            # make sure that you are not linking the state data and the
            # thread_manager.
            thread_dict = thread_dict.copy()
            # Recreate the unpicklable context by writing the saved
            # registers into the emulator and snapshotting it.
            for val, reg in zip(reg_vals, self.regs_to_save):
                self.emu.set_reg(reg, val)
            thread_dict["context"] = self.emu.context_save()
            # Get a thread, we will set attributes manually.
            t = Thread(None, None, None, None, None)
            t.__dict__ = thread_dict
            return t

        self.thread_list = [
            _deserialize_thread(d) for d in data["thread_list"]
        ]
        # End with loading the current_thread, so that execution state
        # is ready to go.
        current_thread = self.get_thread(data["current_thread_tid"])
        self._load(current_thread)
        self.thread_count = data["thread_count"]