Example 1
    def __init__(self, txt_trace, *args, pc_only=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.pc_only = pc_only

        self.progress = ProgressPrinter(len(self), "Scan traces")

        # txt trace parser state machine
        self.txt_parse_state = State.S_INSTR
        self.txt_trace = open(txt_trace, "r")
        # skip lines from the txt trace until the first
        # instruction
        self._skiplines(inst_only=True)
        self.txt_parse_state = State.S_INSTR
Example 2
    def __init__(self, dataset, trace_path, **kwargs):
        super(CallbackTraceParser, self).__init__(trace_path, **kwargs)

        self.dataset = dataset
        """The dataset where the parsed data will be stored"""

        self.progress = ProgressPrinter(len(self),
                                        desc="Scanning trace %s" % trace_path)
        """Progress object to display feedback to the user"""

        self._last_regs = None
        """Snapshot of the registers of the previous instruction"""

        self._dis = pct.disassembler()
        """Disassembler"""

        # Enumerate the callbacks at creation time to save
        # time during scanning
        self._callbacks = {}

        # for each opcode we may be interested in, check whether there
        # are one or more callbacks to call; if so, they are stored in
        # _callbacks[<opcode>] so that the _get_callbacks function can
        # retrieve them in ~O(1)
        for attr in dir(self):
            method = getattr(self, attr)
            if (not attr.startswith("scan_") or not callable(method)):
                continue
            instr_name = attr[5:]
            for iclass in Instruction.IClass:
                if instr_name == iclass.value:
                    # add the iclass callback for all the
                    # instructions in that class
                    opcodes = Instruction.iclass_map.get(iclass, [])
                    for opcode in opcodes:
                        self._callbacks.setdefault(opcode, []).append(method)
                    break
            else:
                self._callbacks.setdefault(instr_name, []).append(method)

        logger.debug("Loaded callbacks for CallbackTraceParser %s",
                     self._callbacks)
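
The `for ... else` above is load-bearing: the `else` branch runs only when the loop over instruction classes completes without a `break`, i.e. when the method name is a plain opcode rather than a class name. A minimal self-contained sketch of the same registration pattern (the toy IClass values, opcode map, and method names are illustrative, not from the source):

from enum import Enum

class IClass(Enum):
    CAP_LOAD = "cap_load"

# toy class -> opcode map, standing in for Instruction.iclass_map
iclass_map = {IClass.CAP_LOAD: ["clc", "clb"]}

callbacks = {}
for name in ["scan_cap_load", "scan_ld"]:
    instr_name = name[5:]
    for iclass in IClass:
        if instr_name == iclass.value:
            # class name: register for every opcode in the class
            for opcode in iclass_map.get(iclass, []):
                callbacks.setdefault(opcode, []).append(name)
            break
    else:
        # loop ended without break: plain opcode name
        callbacks.setdefault(instr_name, []).append(name)

# callbacks == {"clc": ["scan_cap_load"], "clb": ["scan_cap_load"],
#               "ld": ["scan_ld"]}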
Example 3
    def build_dataset(self):
        """Process the provenance graph to extract histogram data."""
        super(CapSizeDerefPlot, self).build_dataset()

        # indexes in the vmmap and in the norm_histograms are
        # the same.
        vm_entries = list(self.vmmap)
        vm_ranges = [Range(v.start, v.end) for v in self.vmmap]
        hist_data = [[] for _ in range(len(vm_ranges))]

        progress = ProgressPrinter(self.dataset.num_vertices(),
                                   desc="Sorting capability references")
        for node in self.dataset.vertices():
            data = self.dataset.vp.data[node]
            # iterate over every dereference of the node
            for addr in data.deref["addr"]:
                # check in which vm-entry the address is
                for idx, r in enumerate(vm_ranges):
                    if addr in r:
                        hist_data[idx].append(data.cap.length)
                        break
            progress.advance()
        progress.finish()

        for vm_entry, data in zip(vm_entries, hist_data):
            if len(data) == 0:
                continue
            # the bins are logarithmic (uniform in log2 space)
            data = np.log2(data)
            h, b = np.histogram(data, bins=self.n_bins)
            # append histogram to the dataframes
            # self.hist_sources.append(vm_entry)
            # new_index = len(self.abs_histogram.index)
            self.abs_histogram.loc[vm_entry] = h
            self.norm_histogram.loc[vm_entry] = h / np.sum(h)
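
A minimal sketch of the binning performed above, with synthetic capability lengths (the number of bins is an assumption standing in for n_bins):

import numpy as np

lengths = [16, 32, 32, 4096, 65536]   # synthetic capability lengths
log_lengths = np.log2(lengths)        # histogram over log2(length)
h, edges = np.histogram(log_lengths, bins=4)
norm = h / np.sum(h)                  # normalised row, as in norm_histogram
# h counts lengths per log2-sized bin; norm sums to 1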
Example 4
    def build_dataset(self):
        """Process the provenance graph to extract histogram data."""
        super(CapSizeCreationPlot, self).build_dataset()

        # indexes in the vmmap and in the norm_histograms are
        # the same.
        vm_entries = list(self.vmmap)
        vm_ranges = [Range(v.start, v.end) for v in self.vmmap]
        hist_data = [[] for _ in range(len(vm_entries))]

        progress = ProgressPrinter(self.dataset.num_vertices(),
                                   desc="Sorting capability references")
        logger.debug("Vm ranges %s", vm_ranges)
        for node in self.dataset.vertices():
            data = self.dataset.vp.data[node]
            for idx, r in enumerate(vm_ranges):
                if Range(data.cap.base, data.cap.bound) in r:
                    hist_data[idx].append(data.cap.length)
            progress.advance()
        progress.finish()

        for vm_entry, data in zip(vm_entries, hist_data):
            logger.debug("hist entry len %d", len(data))
            if len(data) == 0:
                continue
            # the bins are logarithmic (uniform in log2 space)
            data = np.log2(data)
            h, b = np.histogram(data, bins=self.n_bins)
            # append histograms to the dataframe
            # self.hist_sources.append(vm_entry)
            # new_index = len(self.abs_histogram.index)
            self.abs_histogram.loc[vm_entry] = h
            self.norm_histogram.loc[vm_entry] = h / np.sum(h)
Example 5
 def build_dataset(self):
     """
     For each capability with exec permissions, merge its
     store map to a common dictionary.
     The common dictionary is then used for the plot.
     """
     super().build_dataset()
     progress = ProgressPrinter(
         self.dataset.num_vertices(),
         desc="Extract executable cap memory locations")
     for node in self.dataset.vertices():
         node_data = self.dataset.vp.data[node]
         if node_data.cap.has_perm(CheriCapPerm.EXEC):
             for addr in node_data.address.values():
                 self.range_builder.inspect(addr)
             self.store_addr_map.update(node_data.address)
         progress.advance()
     progress.finish()
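
The has_perm test above is a bitwise flag check. A hedged sketch of how such a test typically works (the CheriCapPerm bit values here are illustrative, not the real CHERI encoding):

from enum import IntFlag

class CheriCapPerm(IntFlag):
    # illustrative bit positions only
    LOAD = 1 << 0
    STORE = 1 << 1
    EXEC = 1 << 2

perms = CheriCapPerm.LOAD | CheriCapPerm.EXEC
assert perms & CheriCapPerm.EXEC       # the EXEC bit is set
assert not perms & CheriCapPerm.STORE  # the STORE bit is not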
Example 6
    def _prepare_patches(self):
        """
        Prepare the patches and address ranges in the patch_builder
        and range_builder.
        """

        dataset_progress = ProgressPrinter(self.dataset.num_vertices(),
                                           desc="Adding nodes")
        for item in self.dataset.vertices():
            data = self.dataset.vp.data[item]
            self.patch_builder.inspect(data)
            self.range_builder.inspect(data)
            dataset_progress.advance()
        dataset_progress.finish()

        if self.vmmap:
            logger.debug("Generate mmap regions")
            for vme in self.vmmap:
                self.vmmap_patch_builder.inspect(vme)
                self.range_builder.inspect_range(Range(vme.start, vme.end))
Example 7
    def build_dataset(self):
        """
        Build the provenance tree
        """
        if self.caching:
            try:
                logger.debug("Load cached provenance graph")
                self.dataset = load_graph(self._get_cache_file())
            except IOError:
                self.parser.parse()
                self.dataset.save(self._get_cache_file())
        else:
            self.parser.parse()

        num_nodes = self.dataset.num_vertices()
        logger.debug("Total nodes %d", num_nodes)
        vertex_mask = self.dataset.new_vertex_property("bool")

        progress = ProgressPrinter(num_nodes, desc="Search kernel nodes")
        for node in self.dataset.vertices():
            # remove null capabilities
            # remove operations in kernel mode
            vertex_data = self.dataset.vp.data
            node_data = vertex_data[node]

            if ((node_data.pc != 0 and node_data.is_kernel) or
                (node_data.cap.length == 0 and node_data.cap.base == 0)):
                vertex_mask[node] = True
            progress.advance()
        progress.finish()

        self.dataset.set_vertex_filter(vertex_mask, inverted=True)
        vertex_mask = self.dataset.copy_property(vertex_mask)

        num_nodes = self.dataset.num_vertices()
        logger.debug("Filtered kernel nodes, remaining %d", num_nodes)
        progress = ProgressPrinter(
            num_nodes, desc="Merge (cfromptr + csetbounds) sequences")

        for node in self.dataset.vertices():
            progress.advance()
            # merge cfromptr -> csetbounds subtrees
            num_parents = node.in_degree()
            if num_parents == 0:
                # root node
                continue
            elif num_parents > 1:
                logger.error("Found node with more than a single parent %s", node)
                raise RuntimeError("Too many parents for a node")

            parent = next(node.in_neighbours())
            parent_data = self.dataset.vp.data[parent]
            node_data = self.dataset.vp.data[node]
            if (parent_data.origin == CheriNodeOrigin.FROMPTR and
                node_data.origin == CheriNodeOrigin.SETBOUNDS):
                # the child must be unique to avoid complex logic when
                # merging; handling multiple children may be desirable
                # for more complex traces
                node_data.origin = CheriNodeOrigin.PTR_SETBOUNDS
                if parent.in_degree() == 1:
                    next_parent = next(parent.in_neighbours())
                    vertex_mask[parent] = True
                    self.dataset.add_edge(next_parent, node)
                elif parent.in_degree() == 0:
                    vertex_mask[parent] = True
                else:
                    logger.error("Found node with more than a single parent %s",
                                 parent)
                    raise RuntimeError("Too many parents for a node")
        progress.finish()

        self.dataset.set_vertex_filter(vertex_mask, inverted=True)
        vertex_mask = self.dataset.copy_property(vertex_mask)

        num_nodes = self.dataset.num_vertices()
        logger.debug("Merged (cfromptr + csetbounds), remaining %d", num_nodes)
        progress = ProgressPrinter(num_nodes, desc="Find short-lived cfromptr")

        for node in self.dataset.vertices():
            progress.advance()
            node_data = self.dataset.vp.data[node]

            if node_data.origin == CheriNodeOrigin.FROMPTR:
                vertex_mask[node] = True
            # if (node_data.origin == CheriNodeOrigin.FROMPTR and
            #     len(node_data.address) == 0 and
            #     len(node_data.deref["load"]) == 0 and
            #     len(node_data.deref["load"]) == 0):
            #     # remove cfromptr that are never stored or used in
            #     # a dereference
            #     remove_list.append(node)
        progress.finish()

        self.dataset.set_vertex_filter(vertex_mask, inverted=True)
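
A minimal sketch of the mask-and-filter idiom used throughout this example: the boolean property marks vertices to drop, and set_vertex_filter(..., inverted=True) keeps only the unmarked ones (requires graph-tool):

from graph_tool.all import Graph

g = Graph()
v0, v1, v2 = [g.add_vertex() for _ in range(3)]

drop = g.new_vertex_property("bool")      # True marks a vertex for removal
drop[v1] = True

g.set_vertex_filter(drop, inverted=True)  # keep vertices where drop is False
assert g.num_vertices() == 2              # v1 is filtered out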
Example 8
    def plot(self):
        graph_size = self.dataset.num_vertices()
        # (addr, num_allocations)
        addresses = {}
        page_use = {}
        page_size = 2**12

        # address reuse metric:
        # num_allocations vs address, linearly and in 4KiB chunks
        tree_progress = ProgressPrinter(graph_size, desc="Fetching addresses")
        for node in self.dataset.vertices():
            data = self.dataset.vp.data[node]
            for time, addr in data.address.items():
                try:
                    addresses[addr] += 1
                except KeyError:
                    addresses[addr] = 1
                page_addr = addr & (~0xfff)
                try:
                    page_use[page_addr] += 1
                except KeyError:
                    page_use[page_addr] = 1
            tree_progress.advance()
        tree_progress.finish()

        # time vs address
        # address working set over time

        fig = plt.figure(figsize=(15, 10))
        ax = fig.add_axes([0.05, 0.15, 0.9, 0.80],
                          projection="custom_addrspace")
        ax.set_ylabel("Number of pointers stored")
        ax.set_xlabel("Virtual address")
        ax.set_yscale("log")
        # ax.set_ylim(0, )
        data = np.array(sorted(page_use.items(), key=lambda i: i[0]))
        # ignore empty address-space chunks
        prev_addr = data[0]
        omit_ranges = []
        first_tick = data[0][0]
        ticks = [first_tick]
        labels = ["0x%x" % int(first_tick)]
        for addr in data:
            logger.debug("DATA 0x%x (%d)", int(addr[0]), addr[0])
            if addr[0] - prev_addr[0] > page_size:
                omit_ranges.append([prev_addr[0] + page_size, addr[0]])
                ticks.append(addr[0])
                labels.append("0x%x" % int(addr[0]))
            prev_addr = addr
        ax.set_omit_ranges(omit_ranges)
        # ax.set_xticks(ticks)
        # ax.set_xticklabels(labels, rotation="vertical")
        ax.set_xlim(first_tick - page_size, data[-1][0] + page_size)
        ax.vlines(data[:, 0], [1] * len(data[:, 0]), data[:, 1], color="b")

        fig.savefig(self._get_plot_file())
        return fig
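
A minimal sketch of the page-granularity counting above: masking the low 12 bits maps an address to its 4KiB page base, and collections.Counter replaces the try/except increment:

from collections import Counter

PAGE_MASK = ~0xfff                     # clear the low 12 bits (4KiB pages)

addrs = [0x1000, 0x1008, 0x1ff0, 0x3004]
page_use = Counter(a & PAGE_MASK for a in addrs)
# Counter({0x1000: 3, 0x3000: 1})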
Example 9
    def _extract_ranges(self):
        """
        Extract ranges from the provenance graph.

        XXX for now we do the prototype data manipulation here with a
        naive RangeSet object; later we may want to move it to a more
        dedicated solution using interval trees.
        """
        dataset_progress = ProgressPrinter(
            self.dataset.num_vertices(), desc="Extract frequency of reference")
        range_set = RangeSet()
        for vertex in self.dataset.vertices():
            node = self.dataset.vp.data[vertex]
            logger.debug("Inspect node %s", node)
            r_node = self.DataRange(node.cap.base,
                                    node.cap.base + node.cap.length)
            node_set = RangeSet([r_node])
            # erode r_node until it is fully merged into the range_set;
            # node_set holds the intermediate ranges remaining to merge
            while len(node_set):
                logger.debug("merging node")
                # pop the first range from the node set and try to merge it
                r_current = node_set.pop(0)
                # get first overlapping range
                r_overlap = range_set.pop_overlap_range(r_current)
                if r_overlap is None:
                    # no overlap occurred, just add it to the rangeset
                    range_set.append(r_current)
                    logger.debug("-> no overlap")
                    continue
                logger.debug("picked current %s", r_current)
                logger.debug("picked overlap %s", r_overlap)
                # merge r_current and r_overlap data and push any remaining
                # part of r_current back in node_set
                #
                # r_same: referenced count does not change
                # r_inc: referenced count incremented
                # r_rest: pushed back to node_set for later evaluation
                if r_overlap.start <= r_current.start:
                    logger.debug("overlap before current")
                    # 2 possible layouts:
                    #          |------ r_current -------|
                    # |------ r_overlap -----|
                    # |-r_same-|-- r_inc ----|- r_rest -|
                    #
                    # |--------------- r_overlap --------------|
                    # |-r_same-|-------- r_inc ---------|r_same|
                    r_same, other = r_overlap.split(r_current.start)
                    if r_same.size > 0:
                        range_set.append(r_same)

                    if r_current.end >= r_overlap.end:
                        # other is the remaining part of r_overlap
                        # which falls all in r_current, so
                        # r_inc = other
                        other.num_references += 1
                        range_set.append(other)
                        # r_rest must be computed from the end
                        # of r_overlap
                        _, r_rest = r_current.split(r_overlap.end)
                        if r_rest.size > 0:
                            node_set.append(r_rest)
                    else:
                        # other does not fall all in r_current so
                        # split other in r_inc and r_same
                        # r_current is not pushed back because it
                        # was fully covered by r_overlap
                        r_inc, r_same = other.split(r_current.end)
                        r_inc.num_references += 1
                        range_set.append(r_inc)
                        range_set.append(r_same)
                else:
                    logger.debug("current before overlap")
                    # 2 possible layouts:
                    # |------ r_current ---------|
                    #          |------ r_overlap ---------|
                    # |-r_rest-|-- r_inc --------| r_same |
                    #
                    # |------ r_current --------------|
                    #        |--- r_overlap ---|
                    # |r_rest|----- r_inc -----|r_rest|
                    r_rest, other = r_current.split(r_overlap.start)
                    if r_rest.size > 0:
                        node_set.append(r_rest)

                    if r_current.end >= r_overlap.end:
                        # other is the remaining part of r_current
                        # which completely covers r_overlap so
                        # split other in r_inc and r_rest
                        r_inc, r_rest = other.split(r_overlap.end)
                        r_inc.num_references += r_overlap.num_references
                        range_set.append(r_inc)
                        if r_rest.size > 0:
                            node_set.append(r_rest)
                    else:
                        # other does not cover all r_overlap
                        # so r_inc = other and the remaining
                        # part of r_overlap is r_same
                        other.num_references += r_overlap.num_references
                        range_set.append(other)
                        _, r_same = r_overlap.split(r_current.end)
                        range_set.append(r_same)
                logger.debug("merge loop out Range set step %s", range_set)
                logger.debug("merge loop out Node set step %s", node_set)
            logger.debug("Range set step %s", range_set)
            logger.debug("Node set step %s", node_set)
            dataset_progress.advance()
        dataset_progress.finish()
        logger.debug("Range set %s", range_set)
        self.range_set = range_set
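
A worked instance of the "overlap before current" branch may help: with r_overlap = [0, 10) holding num_references = 1 already in range_set, and r_current = [5, 15) arriving from a node, the splits produce r_same = [0, 5) (count unchanged), r_inc = [5, 10) (count bumped to 2), and r_rest = [10, 15) pushed back onto node_set for the next iteration. A toy model of the split semantics the algorithm assumes (illustrative, not the real DataRange):

class R:
    def __init__(self, start, end, refs=1):
        self.start, self.end, self.num_references = start, end, refs

    @property
    def size(self):
        return self.end - self.start

    def split(self, addr):
        # cut at addr; both halves inherit the reference count
        return (R(self.start, addr, self.num_references),
                R(addr, self.end, self.num_references))

r_overlap, r_current = R(0, 10), R(5, 15)
r_same, other = r_overlap.split(r_current.start)  # [0,5) and [5,10)
other.num_references += 1                         # r_inc = [5,10), refs=2
_, r_rest = r_current.split(r_overlap.end)        # r_rest = [10,15)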
Example 10
class TxtTraceCmpParser(CallbackTraceParser):
    """
    Compare a text trace with a binary trace and
    report any difference.
    """
    def __init__(self, txt_trace, *args, pc_only=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.pc_only = pc_only

        self.progress = ProgressPrinter(len(self), "Scan traces")

        # txt trace parser state machine
        self.txt_parse_state = State.S_INSTR
        self.txt_trace = open(txt_trace, "r")
        # skip lines from the txt trace until the first
        # instruction
        self._skiplines(inst_only=True)
        self.txt_parse_state = State.S_INSTR
        # while True:
        #     saved_pos = self.txt_trace.tell()
        #     line = self.txt_trace.readline()
        #     if re.match("[0-9xa-f]+:", line):
        #         self.txt_trace.seek(saved_pos)
        #         break

    def _skiplines(self, inst_only=False):
        """Skip lines that are not used"""

        while True:
            saved_pos = self.txt_trace.tell()
            line = self.txt_trace.readline()
            # test all the patterns that should not be skipped
            if not inst_only:
                if re.search("Cap Memory Read", line) is not None:
                    self.txt_parse_state = State.S_CAP_MEM
                    break
                if re.search("Cap Memory Write", line) is not None:
                    self.txt_parse_state = State.S_CAP_MEM
                    break
                if re.search("Memory Read", line) is not None:
                    self.txt_parse_state = State.S_MEM
                    break
                if re.search("Memory Write", line) is not None:
                    self.txt_parse_state = State.S_MEM
                    break
                if re.search("Write [C\$]?[a-z0-9]+", line) is not None:
                    self.txt_parse_state = State.S_REG
                    break
            if re.match("[0-9xa-f]+:", line) is not None:
                # the next call to the parser function will
                # continue from here
                self.txt_parse_state = State.S_INSTR_END
                break
        self.txt_trace.seek(saved_pos)

    def _txt_instr(self, inst):
        line = self.txt_trace.readline()
        # line matches "[0-9xa-f]+:"
        # parse addr
        addr, rest = line.split(':')
        _, addr = addr.split("x")
        intaddr = struct.unpack(">Q", bytes.fromhex(addr))[0]
        inst["pc"] = intaddr
        rest = re.sub("[ \t]+", " ", rest.strip())
        opcode = rest.split(" ")[0]
        inst["opcode"] = opcode
        if len(rest.split(" ")) > 1:
            operands = rest.split(" ")[1]
            op0 = operands.split(",")[0]
        else:
            op0 = None

        # a "li zero, <something>" is a canonical nop, so we need to
        # skip until the next instruction is found
        if inst["opcode"] == "li" and op0 == "zero":
            self._skiplines(inst_only=True)
        else:
            # seek to next valid line and change state
            self._skiplines()

    def _txt_reg(self, inst):
        line = self.txt_trace.readline()
        m = re.search("Write \$?([a-z0-9]+) = ([a-f0-9]+)", line)
        if m:
            # write to gpr format
            # Write t4 = 0000000000008400
            reg = m.group(1)
            val = m.group(2)
            intval = struct.unpack(">Q", bytes.fromhex(val))[0]
            inst["reg"] = reg
            inst["data"] = intval
        else:
            # write to cap register format
            # Write C24|v:1 s:0 p:7fff807d b:0000007fffffdb20 l:0000000000000400
            # |o:0000000000000000 t:0
            m = re.search(
                r"Write C([0-9]+)\|v:([01]) s:([01]) p:([a-f0-9]+) "
                r"b:([a-f0-9]+) l:([a-f0-9]+)", line)
            if m is None:
                raise RuntimeError("Malformed cap reg write")
            # the first line of the capability write matched;
            # the next line must match the continuation format
            line = self.txt_trace.readline()
            nxt = re.search("\|o:([a-f0-9]+) t:([a-f0-9]+)", line)
            if nxt is None:
                raise RuntimeError("Malformed cap reg write")
            v = m.group(2)
            s = m.group(3)
            p = m.group(4)
            b = m.group(5)
            l = m.group(6)
            o = nxt.group(1)
            t = nxt.group(2)
            try:
                if len(t) % 2:
                    # pad: fromhex() does not accept an odd number of digits
                    t = "0" + t
                t = bytes.fromhex(t)
                if len(t) < 4:
                    for i in range(4 - len(t)):
                        t = bytes.fromhex("00") + t
            except Exception:
                logger.error("Can not load type field %s %s", m.groups(),
                             nxt.groups())
                raise
            # keep only the low 16 bits of the permissions; the upper
            # 16 bits (uperms) are stored in the trace but ignored by
            # cheritrace
            intp = struct.unpack(">L", bytes.fromhex(p))[0] & 0xffff
            intb = struct.unpack(">Q", bytes.fromhex(b))[0]
            intl = struct.unpack(">Q", bytes.fromhex(l))[0]
            into = struct.unpack(">Q", bytes.fromhex(o))[0]
            intt = struct.unpack(">L", t)[0] & 0x00ffffff
            inst["cap"] = {
                "valid": int(v),
                "sealed": int(s),
                "perms": intp,
                "base": intb,
                "length": intl,
                "offset": into,
                "otype": intt,
            }
        # seek to next valid line and change state
        self._skiplines()

    def _txt_mem(self, inst):
        line = self.txt_trace.readline()
        m = re.search("(Cap )?Memory Read +\[([0-9a-f]+)\]", line)
        if m:
            # data load
            is_cap = m.group(1)
            addr = m.group(2)
            intaddr = struct.unpack(">Q", bytes.fromhex(addr))[0]
            inst["load"] = intaddr
            if is_cap:
                # skip another line
                self.txt_trace.readline()
        else:
            m = re.search("(Cap )?Memory Write +\[([0-9a-f]+)\]", line)
            if m is None:
                raise RuntimeError("Mem not a read nor a write")
            #data store
            is_cap = m.group(1)
            addr = m.group(2)
            intaddr = struct.unpack(">Q", bytes.fromhex(addr))[0]
            inst["store"] = intaddr
            if is_cap:
                # skip another line
                self.txt_trace.readline()
        # seek to next valid line and change state
        self._skiplines()

    def _next_txt_instr(self):
        """
        Fetch the next instruction from the txt trace.
        This is the state machine main loop.
        """
        instr = {}

        while self.txt_parse_state != State.S_INSTR_END:
            if self.txt_parse_state == State.S_SKIP:
                self._skiplines()
            elif self.txt_parse_state == State.S_INSTR:
                self._txt_instr(instr)
            elif self.txt_parse_state == State.S_REG:
                self._txt_reg(instr)
            elif self.txt_parse_state == State.S_MEM:
                self._txt_mem(instr)
            elif self.txt_parse_state == State.S_CAP_MEM:
                self._txt_mem(instr)
        # the next call always starts from an instruction
        self.txt_parse_state = State.S_INSTR
        return instr

    def _dump_txt_inst(self, txt_inst):
        string = "pc:0x%x %s" % (txt_inst["pc"], txt_inst["opcode"])
        if "load" in txt_inst:
            string += " load:%x" % txt_inst["load"]
        if "store" in txt_inst:
            string += " store:%x" % txt_inst["store"]
        if "data" in txt_inst:
            string += " val:%x" % txt_inst["data"]
        if "cap" in txt_inst:
            txt_cap = txt_inst["cap"]
            string += " v:%d s:%d b:%x o:%x l:%x p:%x t:%x" % (
                txt_cap["valid"], txt_cap["sealed"], txt_cap["base"],
                txt_cap["offset"], txt_cap["length"], txt_cap["perms"],
                txt_cap["otype"])
        return string

    def _parse_exception(self, entry, regs, disasm, idx):
        super()._parse_exception(entry, regs, disasm, idx)

        # read the next entry from the txt trace
        txt_inst = self._next_txt_instr()
        logger.debug("Scan txt:<%s>, bin:<unparsed>",
                     self._dump_txt_inst(txt_inst))
        # check only the pc, which must be valid anyway
        assert txt_inst["pc"] == entry.pc

    def scan_all(self, inst, entry, regs, last_regs, idx):

        # read the next entry from the txt trace
        txt_inst = self._next_txt_instr()
        logger.debug("Scan txt:<%s>, bin:%s", self._dump_txt_inst(txt_inst),
                     inst)
        try:
            # check that the instruction matches
            assert txt_inst["pc"] == entry.pc
            if self.pc_only:
                # only check pc, skip everything else
                return False
            if inst.opcode in ["mfc0"]:
                # these have weird behaviour so just ignore for now
                return False

            if txt_inst["opcode"] != inst.opcode:
                # the opcode check is not mandatory due to disassembly
                # differences; issue a warning anyway for now
                logger.warning("Opcode differ {%d} txt:<%s> bin:%s",
                               entry.cycles, self._dump_txt_inst(txt_inst),
                               inst)
            if "load" in txt_inst:
                assert txt_inst["load"] == entry.memory_address
            if "store" in txt_inst:
                assert txt_inst["store"] == entry.memory_address
            if "data" in txt_inst:
                if inst.opcode not in ["mfc0"]:
                    reg_number = entry.gpr_number()
                    for op in inst.operands:
                        if op.is_register and op.gpr_index == reg_number:
                            logger.debug("gpr:%d reg:%d")
                            assert txt_inst["data"] == op.value, \
                                "reg data do not match %d != %d" % (
                                    txt_inst["data"], op.value)
                            break
                #     # XXX we have a problem with extracting the jump target
                #     # from jal/j the binary trace have an offset that does
                #     # not make much sense..
                #     assert txt_inst["data"] == inst.op0.value
            if "cap" in txt_inst:
                cap = CheriCap(inst.op0.value)
                txt_cap = txt_inst["cap"]
                assert txt_cap["valid"] == cap.valid, \
                    "tag do not match %d != %d" % (
                        txt_cap["valid"], cap.valid)
                assert txt_cap["sealed"] == cap.sealed, \
                    "seal do not match %d != %d" % (
                        txt_cap["sealed"], cap.sealed)
                assert txt_cap["base"] == cap.base, \
                    "base do not match %x != %x" % (
                        txt_cap["base"], cap.base)
                assert txt_cap["length"] == cap.length, \
                    "length do not match %x != %x" % (
                        txt_cap["length"], cap.length)
                assert txt_cap["offset"] == cap.offset, \
                    "offset do not match %x != %x" % (
                        txt_cap["offset"], cap.offset)
                assert txt_cap["perms"] == cap.permissions, \
                    "perms do not match %x != %x" % (
                        txt_cap["perms"], cap.permissions)
                assert txt_cap["otype"] == cap.objtype, \
                    "otype do not match %x != %x" % (
                        txt_cap["otype"], cap.objtype)

        except AssertionError:
            logger.error("Assertion failed at {%d} inst:%s txt:<%s>",
                         entry.cycles, inst, self._dump_txt_inst(txt_inst))
            raise
        self.progress.advance()
        return False
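
The parser relies on a State enumeration that is referenced but never shown in these examples; a plausible minimal definition, inferred from the states used above (an assumption, not the original source):

from enum import Enum

class State(Enum):
    S_INSTR = 0      # at an instruction line "<pc>: <opcode> <operands>"
    S_INSTR_END = 1  # the next instruction line was reached
    S_REG = 2        # at a "Write <reg> = ..." line
    S_MEM = 3        # at a "Memory Read/Write" line
    S_CAP_MEM = 4    # at a "Cap Memory Read/Write" line
    S_SKIP = 5       # uninteresting line, keep skipping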
Example 11
class CallbackTraceParser(TraceParser):
    """
    Trace parser that provides help to filter
    and normalize instructions.

    This class performs the filtering of instructions
    that are interesting to the parser and calls the appropriate
    callback if it is defined.
    Callback methods must start with "scan_" followed by the opcode
    or instruction class (e.g. scan_ld will be invoked every time an
    "ld" instruction is found, scan_cap_load will be invoked every time
    a load or store through a capability is found).
    The callback must have the follwing signature:
    scan_<name>(inst, entry, regs, last_regs, idx).

    Valid instruction class names are:

    * all: all instructions
    * cap: all capability instructions
    * cap_load: all capability loads
    * cap_store: all capability stores
    * cap_arith: all capability pointer manipulations
    * cap_bound: all capability bounds modifications
    * cap_cast: all conversions from and to capability pointers
    * cap_cpreg: all manipulations of ddc, kdc, epcc, kcc
    * cap_other: all capability instructions that do not fall in
    the previous "cap_" classes
    """
    def __init__(self, dataset, trace_path, **kwargs):
        super(CallbackTraceParser, self).__init__(trace_path, **kwargs)

        self.dataset = dataset
        """The dataset where the parsed data will be stored"""

        self.progress = ProgressPrinter(len(self),
                                        desc="Scanning trace %s" % trace_path)
        """Progress object to display feedback to the user"""

        self._last_regs = None
        """Snapshot of the registers of the previous instruction"""

        self._dis = pct.disassembler()
        """Disassembler"""

        # Enumerate the callbacks at creation time to save
        # time during scanning
        self._callbacks = {}

        # for each opcode we may be interested in, check whether there
        # are one or more callbacks to call; if so, they are stored in
        # _callbacks[<opcode>] so that the _get_callbacks function can
        # retrieve them in ~O(1)
        for attr in dir(self):
            method = getattr(self, attr)
            if (not attr.startswith("scan_") or not callable(method)):
                continue
            instr_name = attr[5:]
            for iclass in Instruction.IClass:
                if instr_name == iclass.value:
                    # add the iclass callback for all the
                    # instructions in such class
                    opcodes = Instruction.iclass_map.get(iclass, [])
                    for opcode in opcodes:
                        self._callbacks.setdefault(opcode, []).append(method)
                    break
            else:
                self._callbacks.setdefault(instr_name, []).append(method)

        logger.debug("Loaded callbacks for CallbackTraceParser %s",
                     self._callbacks)

    def _get_callbacks(self, inst):
        """
        Return a list of callback methods that should be called to
        parse this instruction

        :param inst: instruction object for the current instruction
        :type inst: :class:`.Instruction`
        :return: list of methods to be called
        :rtype: list of callables
        """
        # try to get the callback for all instructions, if any
        callbacks = list(self._callbacks.get("all", []))
        # the <all> callback should be the last one executed
        callbacks = self._callbacks.get(inst.opcode, []) + callbacks
        return callbacks

    def _parse_exception(self, entry, regs, disasm, idx):
        """
        Callback invoked when an instruction could not be parsed.

        XXX keep this at debug level because the mul instruction always
        fails and is too verbose, but it should be reported as a
        warning/error.
        """
        logger.debug("Error parsing instruction #%d pc:0x%x: %s raw: 0x%x",
                     entry.cycles, entry.pc, disasm.name, entry.inst)

    def parse(self, start=None, end=None, direction=0):
        """
        Parse the trace

        For each trace entry a callback is invoked, some classes
        of instructions cause the callback for the group to be called,
        e.g. scan_cap_load is called whenever a load from memory through
        a capability is found.

        Each instruction opcode can have a callback in the form
        scan_<opcode>.

        :param start: index of the first trace entry to scan
        :type start: int
        :param end: index of the last trace entry to scan
        :type end: int
        :param direction: scan direction (forward = 0, backward=1)
        :type direction: int
        """

        if start is None:
            start = 0
        if end is None:
            end = len(self)
        # coarse progress reporting: calling progress.advance() on
        # every _scan invocation is too expensive
        progress_points = list(range(start, end, int((end - start) / 100) + 1))
        progress_points.append(end)

        def _scan(entry, regs, idx):
            if idx >= progress_points[0]:
                progress_points.pop(0)
                self.progress.advance(to=idx)
            disasm = self._dis.disassemble(entry.inst)
            try:
                if self._last_regs is None:
                    self._last_regs = regs
                inst = Instruction(disasm, entry, regs, self._last_regs)
            except Exception:
                self._parse_exception(entry, regs, disasm, idx)
                return False

            ret = False

            try:
                for cbk in self._get_callbacks(inst):
                    ret |= cbk(inst, entry, regs, self._last_regs, idx)
                    if ret:
                        break
            except Exception as e:
                logger.error("Error in callback %s: %s", cbk, e)
                raise

            self._last_regs = regs
            return ret

        self.trace.scan(_scan, start, end, direction)
        self.progress.finish()
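
A hedged usage sketch: a hypothetical subclass registers one opcode callback and one instruction-class callback, then drives the scan (the dataset object and trace path are assumptions for illustration):

class MyParser(CallbackTraceParser):
    def scan_ld(self, inst, entry, regs, last_regs, idx):
        # invoked for every "ld" instruction
        return False  # True would stop the remaining callbacks for this entry

    def scan_cap_load(self, inst, entry, regs, last_regs, idx):
        # invoked for every instruction in the cap_load class
        return False

parser = MyParser(dataset=None, trace_path="trace.bin")  # hypothetical inputs
parser.parse()  # scan the whole trace, dispatching to the callbacks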