Example #1
    def build_dataset(self):
        """Process the provenance graph to extract histogram data."""
        super(CapSizeCreationPlot, self).build_dataset()

        # Indices in the vmmap and in the histogram data are
        # the same.
        vm_entries = list(self.vmmap)
        vm_ranges = [Range(v.start, v.end) for v in self.vmmap]
        hist_data = [[] for _ in range(len(vm_entries))]

        progress = ProgressPrinter(self.dataset.num_vertices(),
                                   desc="Sorting capability references")
        logger.debug("Vm ranges %s", vm_ranges)
        for node in self.dataset.vertices():
            data = self.dataset.vp.data[node]
            for idx, r in enumerate(vm_ranges):
                if Range(data.cap.base, data.cap.bound) in r:
                    hist_data[idx].append(data.cap.length)
            progress.advance()
        progress.finish()

        for vm_entry, data in zip(vm_entries, hist_data):
            logger.debug("hist entry len %d", len(data))
            if len(data) == 0:
                continue
            # the bin size is logarithmic
            data = np.log2(data)
            h, _ = np.histogram(data, bins=self.n_bins)
            # append histograms to the dataframes
            self.abs_histogram.loc[vm_entry] = h
            self.norm_histogram.loc[vm_entry] = h / np.sum(h)
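
A minimal sketch of the log2 binning step above, with made-up lengths (cap_lengths is illustrative, not taken from any dataset):

    import numpy as np

    # hypothetical capability lengths in bytes
    cap_lengths = [16, 32, 32, 4096, 1 << 20]
    # bin in log2 space, as build_dataset does
    counts, edges = np.histogram(np.log2(cap_lengths), bins=10)
    # normalise the row to a distribution, as for norm_histogram
    norm_counts = counts / np.sum(counts)
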
Example #2
    def build_dataset(self):
        """Process the provenance graph to extract histogram data."""
        super(CapSizeDerefPlot, self).build_dataset()

        # Indices in the vmmap and in the histogram data are
        # the same.
        vm_entries = list(self.vmmap)
        vm_ranges = [Range(v.start, v.end) for v in self.vmmap]
        hist_data = [[] for _ in range(len(vm_ranges))]

        progress = ProgressPrinter(self.dataset.num_vertices(),
                                   desc="Sorting capability references")
        for node in self.dataset.vertices():
            data = self.dataset.vp.data[node]
            # iterate over every dereference of the node
            for addr in data.deref["addr"]:
                # check in which vm-entry the address is
                for idx, r in enumerate(vm_ranges):
                    if addr in r:
                        hist_data[idx].append(data.cap.length)
                        break
            progress.advance()
        progress.finish()

        for vm_entry, data in zip(vm_entries, hist_data):
            if len(data) == 0:
                continue
            # the bin size is logarithmic
            data = np.log2(data)
            h, _ = np.histogram(data, bins=self.n_bins)
            # append histograms to the dataframes
            self.abs_histogram.loc[vm_entry] = h
            self.norm_histogram.loc[vm_entry] = h / np.sum(h)
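
The inner loop above relies on "addr in r" membership tests. A toy Range with that interface (a sketch only; the real Range class may differ):

    class Range:
        """Half-open address range supporting membership tests."""
        def __init__(self, start, end):
            self.start, self.end = start, end

        def __contains__(self, addr):
            return self.start <= addr < self.end

    vm_ranges = [Range(0x1000, 0x2000), Range(0x4000, 0x8000)]
    # index of the vm-entry containing the address, None if unmapped
    region = next((i for i, r in enumerate(vm_ranges) if 0x4100 in r),
                  None)  # -> 1
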
Example #3
    def build_dataset(self):
        """
        For each capability with exec permissions, merge its
        store map into a common dictionary.
        The common dictionary is then used for the plot.
        """
        super().build_dataset()
        progress = ProgressPrinter(
            self.dataset.num_vertices(),
            desc="Extract executable cap memory locations")
        for node in self.dataset.vertices():
            node_data = self.dataset.vp.data[node]
            if node_data.cap.has_perm(CheriCapPerm.EXEC):
                for addr in node_data.address.values():
                    self.range_builder.inspect(addr)
                self.store_addr_map.update(node_data.address)
            progress.advance()
        progress.finish()
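
The has_perm() call above is presumably a bitmask test. A self-contained approximation, assuming CheriCapPerm behaves like a flag enum (the bit values below are made up):

    from enum import IntFlag

    class CheriCapPerm(IntFlag):
        # illustrative subset of permission bits
        LOAD = 1 << 0
        STORE = 1 << 1
        EXEC = 1 << 2

    perms = CheriCapPerm.LOAD | CheriCapPerm.EXEC
    has_exec = bool(perms & CheriCapPerm.EXEC)  # True
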
Example #4
    def _prepare_patches(self):
        """
        Prepare the patches and address ranges in the patch_builder
        and range_builder.
        """

        dataset_progress = ProgressPrinter(self.dataset.num_vertices(),
                                           desc="Adding nodes")
        for item in self.dataset.vertices():
            data = self.dataset.vp.data[item]
            self.patch_builder.inspect(data)
            self.range_builder.inspect(data)
            dataset_progress.advance()
        dataset_progress.finish()

        if self.vmmap:
            logger.debug("Generate mmap regions")
            for vme in self.vmmap:
                self.vmmap_patch_builder.inspect(vme)
                self.range_builder.inspect_range(Range(vme.start, vme.end))
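
The builders above follow a simple visitor-style pattern: each inspect() or inspect_range() call folds one item into accumulated state. A toy range_builder tracking overall address bounds (a hypothetical interface, shown only to illustrate the pattern):

    class RangeBuilder:
        """Toy accumulator: tracks the bounding range of inspected items."""

        def __init__(self):
            self.start = None
            self.end = None

        def inspect_range(self, r):
            # fold one range into the running bounds
            if self.start is None:
                self.start, self.end = r.start, r.end
            else:
                self.start = min(self.start, r.start)
                self.end = max(self.end, r.end)
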
Example #5
    def build_dataset(self):
        """
        Build the provenance tree
        """
        if self.caching:
            try:
                logger.debug("Load cached provenance graph")
                self.dataset = load_graph(self._get_cache_file())
            except IOError:
                self.parser.parse()
                self.dataset.save(self._get_cache_file())
        else:
            self.parser.parse()

        num_nodes = self.dataset.num_vertices()
        logger.debug("Total nodes %d", num_nodes)
        vertex_mask = self.dataset.new_vertex_property("bool")

        progress = ProgressPrinter(num_nodes, desc="Search kernel nodes")
        vertex_data = self.dataset.vp.data
        for node in self.dataset.vertices():
            # mask out null capabilities and operations in kernel mode
            node_data = vertex_data[node]

            if ((node_data.pc != 0 and node_data.is_kernel) or
                (node_data.cap.length == 0 and node_data.cap.base == 0)):
                vertex_mask[node] = True
            progress.advance()
        progress.finish()

        self.dataset.set_vertex_filter(vertex_mask, inverted=True)
        vertex_mask = self.dataset.copy_property(vertex_mask)

        num_nodes = self.dataset.num_vertices()
        logger.debug("Filtered kernel nodes, remaining %d", num_nodes)
        progress = ProgressPrinter(
            num_nodes, desc="Merge (cfromptr + csetbounds) sequences")

        for node in self.dataset.vertices():
            progress.advance()
            # merge cfromptr -> csetbounds subtrees
            num_parents = node.in_degree()
            if num_parents == 0:
                # root node
                continue
            elif num_parents > 1:
                logger.error("Found node with more than a single parent %s", node)
                raise RuntimeError("Too many parents for a node")

            parent = next(node.in_neighbours())
            parent_data = self.dataset.vp.data[parent]
            node_data = self.dataset.vp.data[node]
            if (parent_data.origin == CheriNodeOrigin.FROMPTR and
                node_data.origin == CheriNodeOrigin.SETBOUNDS):
                # The child must be unique to avoid complex merge
                # logic; handling multiple children may be desirable
                # for more complex traces.
                node_data.origin = CheriNodeOrigin.PTR_SETBOUNDS
                if parent.in_degree() == 1:
                    next_parent = next(parent.in_neighbours())
                    vertex_mask[parent] = True
                    self.dataset.add_edge(next_parent, node)
                elif parent.in_degree() == 0:
                    vertex_mask[parent] = True
                else:
                    logger.error("Found node with more than a single parent %s",
                                 parent)
                    raise RuntimeError("Too many parents for a node")
        progress.finish()

        self.dataset.set_vertex_filter(vertex_mask, inverted=True)
        vertex_mask = self.dataset.copy_property(vertex_mask)

        num_nodes = self.dataset.num_vertices()
        logger.debug("Merged (cfromptr + csetbounds), remaining %d", num_nodes)
        progress = ProgressPrinter(num_nodes, desc="Find short-lived cfromptr")

        for node in self.dataset.vertices():
            progress.advance()
            node_data = self.dataset.vp.data[node]

            if node_data.origin == CheriNodeOrigin.FROMPTR:
                vertex_mask[node] = True
            # if (node_data.origin == CheriNodeOrigin.FROMPTR and
            #     len(node_data.address) == 0 and
            #     len(node_data.deref["load"]) == 0 and
            #     len(node_data.deref["store"]) == 0):
            #     # remove cfromptr that are never stored or used in
            #     # a dereference
            #     remove_list.append(node)
        progress.finish()

        self.dataset.set_vertex_filter(vertex_mask, inverted=True)
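
The masking idiom above (build a boolean vertex property, then set_vertex_filter with inverted=True) hides vertices without mutating the graph. A minimal graph-tool sketch of the same idiom:

    from graph_tool.all import Graph

    g = Graph()
    v0, v1, v2 = g.add_vertex(), g.add_vertex(), g.add_vertex()
    g.add_edge(v0, v1)
    g.add_edge(v1, v2)

    mask = g.new_vertex_property("bool")
    mask[v1] = True  # mark v1 for removal
    g.set_vertex_filter(mask, inverted=True)
    print(g.num_vertices())  # 2: v1 is filtered out, not deleted
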
Example #6
    def plot(self):
        graph_size = self.dataset.num_vertices()
        # (addr, num_allocations)
        addresses = {}
        page_use = {}
        page_size = 2**12

        # address reuse metric
        # num_allocations vs address
        # linearly and in 4k chunks
        tree_progress = ProgressPrinter(graph_size, desc="Fetching addresses")
        for node in self.dataset.vertices():
            data = self.dataset.vp.data[node]
            for _, addr in data.address.items():
                try:
                    addresses[addr] += 1
                except KeyError:
                    addresses[addr] = 1
                page_addr = addr & ~(page_size - 1)
                try:
                    page_use[page_addr] += 1
                except KeyError:
                    page_use[page_addr] = 1
            tree_progress.advance()
        tree_progress.finish()

        # time vs address
        # address working set over time

        fig = plt.figure(figsize=(15, 10))
        ax = fig.add_axes([0.05, 0.15, 0.9, 0.80],
                          projection="custom_addrspace")
        ax.set_ylabel("Number of pointers stored")
        ax.set_xlabel("Virtual address")
        ax.set_yscale("log")
        data = np.array(sorted(page_use.items(), key=lambda i: i[0]))
        # ignore empty address-space chunks
        prev_addr = data[0]
        omit_ranges = []
        first_tick = data[0][0]
        ticks = [first_tick]
        labels = ["0x%x" % int(first_tick)]
        for addr in data:
            logger.debug("DATA 0x%x (%d)", int(addr[0]), addr[0])
            if addr[0] - prev_addr[0] > page_size:
                omit_ranges.append([prev_addr[0] + page_size, addr[0]])
                ticks.append(addr[0])
                labels.append("0x%x" % int(addr[0]))
            prev_addr = addr
        ax.set_omit_ranges(omit_ranges)
        # ax.set_xticks(ticks)
        # ax.set_xticklabels(labels, rotation="vertical")
        ax.set_xlim(first_tick - page_size, data[-1][0] + page_size)
        ax.vlines(data[:, 0], [1] * len(data[:, 0]), data[:, 1], color="b")

        fig.savefig(self._get_plot_file())
        return fig
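
The try/except counting above can be stated more compactly with collections.Counter; a sketch with illustrative addresses:

    from collections import Counter

    page_size = 2 ** 12
    addrs = [0x1000, 0x1008, 0x2010, 0x2fff]
    # mask off the page offset to bucket addresses per 4K page
    page_use = Counter(addr & ~(page_size - 1) for addr in addrs)
    # page_use[0x1000] == 2, page_use[0x2000] == 2
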
Example #7
    def _extract_ranges(self):
        """
        Extract ranges from the provenance graph

        XXX for now do the prototype data manipulation here
        with a naive RangeSet object later we may want to
        move it somewhere else with more dedicated solution
        using interval trees
        """
        dataset_progress = ProgressPrinter(
            self.dataset.num_vertices(), desc="Extract frequency of reference")
        range_set = RangeSet()
        for vertex in self.dataset.vertices():
            node = self.dataset.vp.data[vertex]
            logger.debug("Inspect node %s", node)
            r_node = self.DataRange(node.cap.base,
                                    node.cap.base + node.cap.length)
            node_set = RangeSet([r_node])
            # erode r_node until it is fully merged in the range_set
            # the node_set holds intermediate ranges remaining to merge
            while node_set:
                logger.debug("merging node")
                # pop first range from rangeset and try to merge it
                r_current = node_set.pop(0)
                # get first overlapping range
                r_overlap = range_set.pop_overlap_range(r_current)
                if r_overlap is None:
                    # no overlap occurred, just add it to the rangeset
                    range_set.append(r_current)
                    logger.debug("-> no overlap")
                    continue
                logger.debug("picked current %s", r_current)
                logger.debug("picked overlap %s", r_overlap)
                # merge r_current and r_overlap data and push any remaining
                # part of r_current back in node_set
                #
                # r_same: referenced count does not change
                # r_inc: referenced count incremented
                # r_rest: pushed back to node_set for later evaluation
                if r_overlap.start <= r_current.start:
                    logger.debug("overlap before current")
                    # 2 possible layouts:
                    #          |------ r_current -------|
                    # |------ r_overlap -----|
                    # |-r_same-|-- r_inc ----|- r_rest -|
                    #
                    # |--------------- r_overlap --------------|
                    # |-r_same-|-------- r_inc ---------|r_same|
                    r_same, other = r_overlap.split(r_current.start)
                    if r_same.size > 0:
                        range_set.append(r_same)

                    if r_current.end >= r_overlap.end:
                        # other is the remaining part of r_overlap
                        # which falls all in r_current, so
                        # r_inc = other
                        other.num_references += 1
                        range_set.append(other)
                        # r_rest must be computed from the end
                        # of r_overlap
                        _, r_rest = r_current.split(r_overlap.end)
                        if r_rest.size > 0:
                            node_set.append(r_rest)
                    else:
                        # other does not fall all in r_current so
                        # split other in r_inc and r_same
                        # r_current is not pushed back because it
                        # was fully covered by r_overlap
                        r_inc, r_same = other.split(r_current.end)
                        r_inc.num_references += 1
                        range_set.append(r_inc)
                        range_set.append(r_same)
                else:
                    logger.debug("current before overlap")
                    # 2 possible layouts:
                    # |------ r_current ---------|
                    #          |------ r_overlap ---------|
                    # |-r_rest-|-- r_inc --------| r_same |
                    #
                    # |------ r_current --------------|
                    #        |--- r_overlap ---|
                    # |r_rest|----- r_inc -----|r_rest|
                    r_rest, other = r_current.split(r_overlap.start)
                    if r_rest.size > 0:
                        node_set.append(r_rest)

                    if r_current.end >= r_overlap.end:
                        # other is the remaining part of r_current
                        # which completely covers r_overlap so
                        # split other in r_inc and r_rest
                        r_inc, r_rest = other.split(r_overlap.end)
                        r_inc.num_references += r_overlap.num_references
                        range_set.append(r_inc)
                        if r_rest.size > 0:
                            node_set.append(r_rest)
                    else:
                        # other does not cover all r_overlap
                        # so r_inc = other and the remaining
                        # part of r_overlap is r_same
                        other.num_references += r_overlap.num_references
                        range_set.append(other)
                        _, r_same = r_overlap.split(r_current.end)
                        range_set.append(r_same)
                logger.debug("merge loop out Range set step %s", range_set)
                logger.debug("merge loop out Node set step %s", node_set)
            logger.debug("Range set step %s", range_set)
            logger.debug("Node set step %s", node_set)
            dataset_progress.advance()
        dataset_progress.finish()
        logger.debug("Range set %s", range_set)
        self.range_set = range_set
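
The merge loop above leans on a split() primitive that cuts a range at a point while carrying the reference count. A toy DataRange in that spirit (hypothetical; the real class likely carries more state):

    class DataRange:
        def __init__(self, start, end, num_references=0):
            self.start, self.end = start, end
            self.num_references = num_references

        @property
        def size(self):
            return self.end - self.start

        def split(self, point):
            # both halves inherit the reference count
            return (DataRange(self.start, point, self.num_references),
                    DataRange(point, self.end, self.num_references))

    r_inc, r_rest = DataRange(0x1000, 0x3000, num_references=1).split(0x2000)
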
Example #8
class CallbackTraceParser(TraceParser):
    """
    Trace parser that provides help to filter
    and normalize instructions.

    This class performs the filtering of instructions
    that are interesting to the parser and calls the appropriate
    callback if it is defined.
    Callback methods must start with "scan_" followed by the opcode
    or instruction class (e.g. scan_ld will be invoked every time an
    "ld" instruction is found, scan_cap_load every time a load
    through a capability is found).
    The callback must have the following signature:
    scan_<name>(inst, entry, regs, last_regs, idx).

    Valid instruction class names are:

    * all: all instructions
    * cap: all capability instructions
    * cap_load: all capability loads
    * cap_store: all capability stores
    * cap_arith: all capability pointer manipulation
    * cap_bound: all capability bound modification
    * cap_cast: all conversions from and to capability pointers
    * cap_cpreg: all manipulations of ddc, kdc, epcc, kcc
    * cap_other: all capability instructions that do not fall in
    the previous "cap_" classes
    """
    def __init__(self, dataset, trace_path, **kwargs):
        super(CallbackTraceParser, self).__init__(trace_path, **kwargs)

        self.dataset = dataset
        """The dataset where the parsed data will be stored"""

        self.progress = ProgressPrinter(len(self),
                                        desc="Scanning trace %s" % trace_path)
        """Progress object to display feedback to the user"""

        self._last_regs = None
        """Snapshot of the registers of the previous instruction"""

        self._dis = pct.disassembler()
        """Disassembler"""

        # Enumerate the callbacks at creation time to save
        # time during scanning
        self._callbacks = {}

        # for each opcode we may be interested in, check if there is
        # one or more callbacks to call, if so these will be stored
        # in _callbacks[<opcode>] so that the _get_callbacks function
        # can retrieve them in ~O(1)
        for attr in dir(self):
            method = getattr(self, attr)
            if not attr.startswith("scan_") or not callable(method):
                continue
            instr_name = attr[5:]
            for iclass in Instruction.IClass:
                if instr_name == iclass.value:
                    # add the iclass callback for all the
                    # instructions in such class
                    opcodes = Instruction.iclass_map.get(iclass, [])
                    for opcode in opcodes:
                        if opcode in self._callbacks:
                            self._callbacks[opcode].append(method)
                        else:
                            self._callbacks[opcode] = [method]
                    break
            else:
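                # no instruction class matched: treat the name as a
                # plain opcode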
                if instr_name in self._callbacks:
                    self._callbacks[instr_name].append(method)
                else:
                    self._callbacks[instr_name] = [method]

        logger.debug("Loaded callbacks for CallbackTraceParser %s",
                     self._callbacks)

    def _get_callbacks(self, inst):
        """
        Return a list of callback methods that should be called to
        parse this instruction

        :param inst: instruction object for the current instruction
        :type inst: :class:`.Instruction`
        :return: list of methods to be called
        :rtype: list of callables
        """
        # try to get the callback for all instructions, if any
        callbacks = list(self._callbacks.get("all", []))
        # the <all> callback should be the last one executed
        callbacks = self._callbacks.get(inst.opcode, []) + callbacks
        return callbacks

    def _parse_exception(self, entry, regs, disasm, idx):
        """
        Callback invoked when an instruction could not be parsed
        XXX make this debug because the mul instruction always fails
        and it is too verbose but should report it as a warning/error
        """
        logger.debug("Error parsing instruction #%d pc:0x%x: %s raw: 0x%x",
                     entry.cycles, entry.pc, disasm.name, entry.inst)

    def parse(self, start=None, end=None, direction=0):
        """
        Parse the trace

        For each trace entry a callback is invoked, some classes
        of instructions cause the callback for the group to be called,
        e.g. scan_cap_load is called whenever a load from memory through
        a capability is found.

        Each instruction opcode can have a callback in the form
        scan_<opcode>.

        :param start: index of the first trace entry to scan
        :type start: int
        :param end: index of the last trace entry to scan
        :type end: int
        :param direction: scan direction (forward=0, backward=1)
        :type direction: int
        """

        if start is None:
            start = 0
        if end is None:
            end = len(self)
        # fast progress processing, calling progress.advance() in each
        # _scan call is too expensive
        progress_points = list(range(start, end, (end - start) // 100 + 1))
        progress_points.append(end)

        def _scan(entry, regs, idx):
            if idx >= progress_points[0]:
                progress_points.pop(0)
                self.progress.advance(to=idx)
            disasm = self._dis.disassemble(entry.inst)
            try:
                if self._last_regs is None:
                    self._last_regs = regs
                inst = Instruction(disasm, entry, regs, self._last_regs)
            except Exception:
                self._parse_exception(entry, regs, disasm, idx)
                return False

            ret = False
            cbk = None

            try:
                for cbk in self._get_callbacks(inst):
                    ret |= cbk(inst, entry, regs, self._last_regs, idx)
                    if ret:
                        break
            except Exception as e:
                # cbk is None here only if _get_callbacks itself raised
                logger.error("Error in callback %s: %s", cbk, e)
                raise

            self._last_regs = regs
            return ret

        self.trace.scan(_scan, start, end, direction)
        self.progress.finish()
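
A hypothetical subclass illustrating the scan_ convention documented in the class docstring (MyParser and the chosen callbacks are made up; the signature follows the docstring):

    class MyParser(CallbackTraceParser):

        def scan_csetbounds(self, inst, entry, regs, last_regs, idx):
            # invoked for every "csetbounds" instruction
            return False  # False: continue with remaining callbacks

        def scan_cap_load(self, inst, entry, regs, last_regs, idx):
            # invoked for every load through a capability
            return False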