Exemplo n.º 1
0
    def get_symbols(self, rom_bytes) -> List[Symbol]:
        """Build the ordered list of symbols that partition this segment's data.

        Candidates come from two places: every word read at a 4-byte step over
        ``[self.rom_start, self.rom_end)`` (using the configured endianness) whose
        value is a vram address inside this segment, and every non-dead entry of
        ``self.seg_symbols`` that lies inside this segment. The result is sorted
        by ``vram_start``, always begins with a symbol at ``self.vram_start``, and
        ends with a placeholder ``Symbol(self.vram_end)`` bounding the last real
        symbol. ``check_jtbls`` is then re-run until it reports the list valid
        (it may insert split points in between).
        """
        byte_order = options.get_endianess()
        most_parent = self.get_most_parent()
        candidates = set()

        # Words that look like pointers into this segment become symbols
        for offset in range(self.rom_start, self.rom_end, 4):
            value = int.from_bytes(rom_bytes[offset : offset + 4], byte_order)
            if self.contains_vram(value):
                candidates.add(
                    most_parent.get_symbol(
                        value, create=True, define=True, local_only=True
                    )
                )

        # Plus every live symbol already attached to this segment
        for bucket in self.seg_symbols.values():
            candidates.update(
                entry
                for entry in bucket
                if not entry.dead and self.contains_vram(entry.vram_start)
            )

        ordered: List[Symbol] = sorted(candidates, key=lambda entry: entry.vram_start)

        # The first symbol must sit exactly at the segment start
        starts_at_beginning = bool(ordered) and ordered[0].vram_start == self.vram_start
        if not starts_at_beginning:
            ordered.insert(
                0,
                most_parent.get_symbol(
                    self.vram_start, create=True, define=True, local_only=True
                ),
            )

        # Placeholder whose vram_start bounds the last real symbol's disasm range
        ordered.append(Symbol(self.vram_end))

        # check_jtbls returns False after splitting a symbol; repeat until stable
        while not self.check_jtbls(rom_bytes, ordered):
            pass

        return ordered
Exemplo n.º 2
0
    def check_jtbls(self, rom_bytes, syms: List[Symbol]):
        """Split any "jtbl" symbol whose entries stop looking like a single table.

        For each symbol typed ``"jtbl"``, its bytes (from its ROM offset up to the
        ROM offset of the next entry in ``syms``) are scanned one 4-byte word at a
        time. When two consecutive non-zero words differ by more than 0x100000,
        the jtbl's ``size`` is cut at that word, a symbol for that address is
        inserted right after it in ``syms``, and False is returned so the caller
        can rescan. Zero words are skipped when comparing neighbours.

        Accessing ``syms[index + 1]`` assumes every jtbl is followed by another
        entry in ``syms``.

        Returns:
            True when no symbol needed splitting, otherwise False.
        """
        byte_order = options.get_endianess()
        for index, sym in enumerate(syms):
            if sym.type != "jtbl":
                continue

            parent = self.get_most_parent()
            table_rom = parent.ram_to_rom(sym.vram_start)
            assert isinstance(table_rom, int)
            next_rom = parent.ram_to_rom(syms[index + 1].vram_start)
            table = rom_bytes[table_rom:next_rom]

            previous = 0
            for offset in range(0, len(table), 4):
                current = int.from_bytes(table[offset : offset + 4], byte_order)

                jumped_far = previous and current and abs(previous - current) > 0x100000
                if jumped_far:
                    split_rom = table_rom + offset
                    split_ram = parent.rom_to_ram(split_rom)
                    sym.size = split_rom - sym.rom
                    syms.insert(
                        index + 1,
                        parent.get_symbol(
                            split_ram, create=True, define=True, local_only=True
                        ),
                    )
                    return False

                if current:
                    previous = current

        return True
Exemplo n.º 3
0
    def is_valid_jtbl(self, sym: Symbol, bytes) -> bool:
        """Decide whether ``bytes`` (the data of ``sym``) looks like a jump table.

        The data qualifies when:
          * its length is a multiple of 4 and at least ``min_jtbl_len`` (16)
            bytes, i.e. at least 4 entries;
          * the first word is an address inside some function, as reported by
            ``get_func_for_addr``;
          * every following word lies inside that same function, except that a
            tail consisting only of zero bytes is tolerated once at least
            ``min_jtbl_len`` bytes of entries precede it;
          * no word equals the function's start address.

        On success ``sym.type`` is set to ``"jtbl"`` and the most-parent segment's
        ``jumptables`` maps ``sym.vram_start`` to the function's
        ``(vram_start, vram_end)``.

        Args:
            sym: candidate symbol; only mutated when the check succeeds.
            bytes: the symbol's raw data. (Shadows the builtin; the name is kept
                so keyword callers keep working.)

        Returns:
            True if the data was accepted as a jump table, False otherwise.
        """
        min_jtbl_len = 16

        if len(bytes) % 4 != 0:
            return False

        # Jump tables must be at least min_jtbl_len bytes long, i.e. 4 entries
        if len(bytes) < min_jtbl_len:
            return False

        most_parent = self.get_most_parent()
        assert isinstance(most_parent, CommonSegCode)
        endian = options.get_endianess()

        # Grab the first word and see if its value is an address within a function
        word = int.from_bytes(bytes[0:4], endian)
        jtbl_func: Optional[Symbol] = most_parent.get_func_for_addr(word)

        if not jtbl_func:
            return False

        # A label of a jump table shouldn't point to the start of the function
        if word == jtbl_func.vram_start:
            return False

        for i in range(4, len(bytes), 4):
            word = int.from_bytes(bytes[i : i + 4], endian)

            in_zero_tail = False
            # If the word doesn't contain an address in the current function, this isn't a valid jump table
            if not jtbl_func.contains_vram(word):
                # Allow jump tables that are of a minimum length and end in 0s
                if i < min_jtbl_len or any(b != 0 for b in bytes[i:]):
                    return False
                in_zero_tail = True

            # A label of a jump table shouldn't point to the start of the function
            if word == jtbl_func.vram_start:
                return False

            # Everything from here on is zero, so every later word would get
            # exactly the same verdict; stop instead of re-scanning the tail on
            # each iteration (quadratic in the padding length).
            if in_zero_tail:
                break

        # Mark this symbol as a jump table and record the jump table for later
        sym.type = "jtbl"
        most_parent.jumptables[sym.vram_start] = (
            jtbl_func.vram_start,
            jtbl_func.vram_end,
        )
        return True
Exemplo n.º 4
0
def unpack_color(data, byteorder=None):
    """Decode one 16-bit RGBA5551 pixel into 8-bit-per-channel RGBA.

    The first two bytes of ``data`` are read as an unsigned integer laid out as
    ``RRRRRGGGGGBBBBBA``. Each 5-bit colour channel is scaled to 0-255, rounding
    up, and the 1-bit alpha becomes either 0 or 255.

    Args:
        data: byte sequence; only its first two bytes are read.
        byteorder: ``"big"`` or ``"little"``. When omitted, the project-wide
            endianness from ``options.get_endianess()`` is used, which is the
            only behaviour available to existing callers.

    Returns:
        Tuple ``(r, g, b, a)`` of ints in ``0..255``.
    """
    if byteorder is None:
        byteorder = options.get_endianess()
    s = int.from_bytes(data[0:2], byteorder=byteorder)

    r = (s >> 11) & 0x1F
    g = (s >> 6) & 0x1F
    b = (s >> 1) & 0x1F
    a = (s & 1) * 0xFF

    # Expand 5-bit channels to 8 bits (31 -> 255)
    r = ceil(0xFF * (r / 31))
    g = ceil(0xFF * (g / 31))
    b = ceil(0xFF * (b / 31))

    return r, g, b, a
Exemplo n.º 5
0
    def disassemble_data(self, rom_bytes):
        """Render this segment as a C ``Gfx`` array using the gfxd bindings.

        The bytes ``rom_bytes[self.rom_start:self.rom_end]`` are fed to gfxd with
        this segment's gfxd target and the configured endianness. The macro,
        tlut, timg, cimg, zimg, dl, mtx, lookat, light, vtx and vp callbacks are
        routed to this object's handler methods. The decoded macros are wrapped
        in ``Gfx <symbol name>[] = { ... };`` after the generated-C preamble.

        ``error`` is called when the data length is not a multiple of 8
        (presumably the size of one display-list command; confirm).

        Returns:
            The generated C source text.
        """
        gfx_data = rom_bytes[self.rom_start : self.rom_end]
        segment_length = len(gfx_data)
        if segment_length % 8 != 0:
            error(
                f"Error: gfx segment {self.name} length ({segment_length}) is not a multiple of 8!"
            )

        out_str = options.get_generated_c_premble() + "\n\n"

        sym = self.create_symbol(
            addr=self.vram_start, in_segment=True, type="data", define=True
        )

        gfxd_input_buffer(gfx_data)

        # TODO terrible guess at the size we'll need - improve this
        # bytes(n) zero-fills directly, without first building an n-element list
        outb = bytes(segment_length * 100)
        outbuf = gfxd_output_buffer(outb, len(outb))

        gfxd_target(self.get_gfxd_target())
        gfxd_endian(
            GfxdEndian.big if options.get_endianess() == "big" else GfxdEndian.little, 4
        )

        # Callbacks
        gfxd_macro_fn(self.macro_fn)

        gfxd_tlut_callback(self.tlut_handler)
        gfxd_timg_callback(self.timg_handler)
        gfxd_cimg_callback(self.cimg_handler)
        gfxd_zimg_callback(self.zimg_handler)
        gfxd_dl_callback(self.dl_handler)
        gfxd_mtx_callback(self.mtx_handler)
        gfxd_lookat_callback(self.lookat_handler)
        gfxd_light_callback(self.light_handler)
        # gfxd_seg_callback ?
        gfxd_vtx_callback(self.vtx_handler)
        gfxd_vp_callback(self.vp_handler)
        # gfxd_uctext_callback ?
        # gfxd_ucdata_callback ?
        # gfxd_dram_callback ?

        gfxd_execute()
        out_str += "Gfx " + sym.name + "[] = {\n"
        out_str += gfxd_buffer_to_string(outbuf)
        out_str += "};\n"
        return out_str
Exemplo n.º 6
0
    def gather_jumptable_labels(self, rom_bytes):
        """Collect the code addresses referenced by the parent's jump tables.

        ``self.parent.jumptables`` maps a table's vram address to a
        ``(start, end)`` vram range (the owning function's bounds, as recorded by
        the jump-table detection code). For every table, 4-byte words are read
        from the table's ROM offset, using the configured endianness, for as long
        as each one lies in ``[start, end]``. Each such word is added to
        ``self.parent.jtbl_glabels_to_add``.

        Tables that map to a non-positive ROM offset are skipped, and reading
        stops at the end of ``rom_bytes``.
        """
        # TODO: use the seg_symbols for this
        # jumptables = [j.type == "jtbl" for j in self.seg_symbols]
        endian = options.get_endianess()
        for jumptable, (start, end) in self.parent.jumptables.items():
            rom_offset = self.rom_start + jumptable - self.vram_start

            # Skip just this table; stopping here would drop every table after it
            if rom_offset <= 0:
                continue

            # Read entries until one leaves the function's range or the ROM ends
            while rom_offset + 4 <= len(rom_bytes):
                word_int = int.from_bytes(
                    rom_bytes[rom_offset : rom_offset + 4], endian
                )
                if not start <= word_int <= end:
                    break
                self.parent.jtbl_glabels_to_add.add(word_int)
                rom_offset += 4
Exemplo n.º 7
0
    def check_jtbls(self, rom_bytes, syms: List[Symbol]):
        """Split any "jtbl" symbol whose entries stop looking like one table.

        Every symbol typed ``"jtbl"`` has its bytes (up to the ROM offset of the
        next entry in ``syms``) read as 4-byte words. Zero words are ignored.
        When a non-zero word differs from the previous non-zero word by more
        than 0x100000, the jump table is considered to end there:

          * its ``given_size`` is cut at that word;
          * a new symbol is created for that address and inserted right after
            it in ``syms``;
          * False is returned so the caller can rescan.

        Returns:
            True when no jtbl needed splitting, otherwise False.
        """
        endianness = options.get_endianess()

        for idx, sym in enumerate(syms):
            if sym.type != "jtbl":
                continue

            most_parent = self.get_most_parent()
            start = most_parent.ram_to_rom(syms[idx].vram_start)
            assert isinstance(start, int)
            end = most_parent.ram_to_rom(syms[idx + 1].vram_start)
            data = rom_bytes[start:end]

            words = (
                (pos, int.from_bytes(data[pos : pos + 4], endianness))
                for pos in range(0, len(data), 4)
            )

            last_nonzero = 0
            for pos, value in words:
                if value == 0:
                    continue

                if last_nonzero != 0 and abs(last_nonzero - value) > 0x100000:
                    split_rom = start + pos
                    split_ram = most_parent.rom_to_ram(split_rom)
                    assert sym.rom is not None
                    assert split_ram is not None
                    sym.given_size = split_rom - sym.rom

                    # Not a single jump table after all: start a new symbol where it breaks
                    new_sym = self.create_symbol(
                        split_ram, True, define=True, local_only=True
                    )
                    syms.insert(idx + 1, new_sym)
                    return False

                last_nonzero = value

        return True
Exemplo n.º 8
0
    def process_insns(self, insns, rom_addr):
        """Group a flat list of capstone instructions into functions.

        Each instruction becomes a ``(insn, mnemonic, op_str, rom_addr)`` tuple:

          * ``move`` is expanded back to its real ``or``/``daddu``/``addu`` form;
          * ``jal`` operands and branch targets are replaced with symbol or
            label names;
          * ``mtc0``/``mfc0``/``mtc2``/``mfc2`` get a numeric ``rd`` register.

        ``rom_addr`` is the ROM offset of the first instruction and advances by 4
        per instruction.

        A function is closed in any of these cases:

          * after the delay slot of a ``jr`` that no branch jumps across;
          * once the size recorded on the function's symbol is reached;
          * right before the address of another known function.

        ``nop``s that follow stay attached to the closed function.

        Side effects on ``self.parent``:

          * creates/references ``jal`` target symbols;
          * adds branch targets without a label symbol to ``labels_to_add``;
          * records non-``$ra`` ``jr`` instructions in ``jtbl_jumps``.

        Returns:
            OrderedDict mapping each function's start vram to its list of
            instruction tuples (empty for empty input).
        """
        assert (isinstance(self.parent, CommonSegCode))
        self.parent: CommonSegCode = self.parent

        ret = OrderedDict()

        func_addr = None
        func = []
        end_func = False
        labels = []

        big_endian = options.get_endianess() == "big"

        # Collect (branch address, branch target) pairs so a `jr` that some
        # branch jumps across is not mistaken for the end of the function
        for insn in insns:
            if self.is_branch_insn(insn.mnemonic):
                op_str_split = insn.op_str.split(" ")
                branch_target = op_str_split[-1]
                branch_addr = int(branch_target, 0)
                labels.append((insn.address, branch_addr))

        # Main loop
        for i, insn in enumerate(insns):
            mnemonic = insn.mnemonic
            op_str = insn.op_str
            func_addr = insn.address if len(func) == 0 else func[0][0].address

            # If this is non-zero, disasm exactly this many insns for the function.
            # Floor division keeps it an int so size_remaining can reach 0; a
            # float from `/` never equals 0 when the size isn't a multiple of 4.
            hard_size = 0
            func_sym = self.parent.get_symbol(func_addr, type="func")
            if func_sym and func_sym.size > 4:
                hard_size = func_sym.size // 4

            if mnemonic == "move":
                # Recover the real instruction from the funct field (low 6 bits)
                idx = 3 if big_endian else 0
                opcode = insn.bytes[idx] & 0b00111111

                op_str += ", $zero"

                if opcode == 37:
                    mnemonic = "or"
                elif opcode == 45:
                    mnemonic = "daddu"
                elif opcode == 33:
                    mnemonic = "addu"
                else:
                    print("INVALID INSTRUCTION " + str(insn), opcode)
            elif mnemonic == "jal":
                jal_addr = int(op_str, 0)
                jump_func = self.parent.get_symbol(jal_addr,
                                                   type="func",
                                                   create=True,
                                                   reference=True)
                op_str = jump_func.name
            elif self.is_branch_insn(insn.mnemonic):
                op_str_split = op_str.split(" ")
                branch_target = op_str_split[-1]
                branch_target_int = int(branch_target, 0)

                label = self.parent.get_symbol(branch_target_int,
                                               type="label",
                                               reference=True,
                                               local_only=True)

                if label:
                    label_name = label.name
                else:
                    # No label symbol yet: request a local `.L` label at the target
                    self.parent.labels_to_add.add(branch_target_int)
                    label_name = f".L{branch_target[2:].upper()}"

                op_str = " ".join(op_str_split[:-1] + [label_name])
            elif mnemonic in ["mtc0", "mfc0", "mtc2", "mfc2"]:
                # Replace the register name with the raw rd field number
                idx = 2 if big_endian else 1
                rd = (insn.bytes[idx] & 0xF8) >> 3
                op_str = op_str.split(" ")[0] + " $" + str(rd)

            func.append((insn, mnemonic, op_str, rom_addr))
            rom_addr += 4

            size_remaining = hard_size - len(func) if hard_size > 0 else 0

            if mnemonic == "jr":
                # Record potential jtbl jumps
                if op_str != "$ra":
                    self.parent.jtbl_jumps[insn.address] = op_str

                keep_going = False
                for label in labels:
                    if (label[0] > insn.address and label[1] <= insn.address
                        ) or (label[0] <= insn.address
                              and label[1] > insn.address):
                        keep_going = True
                        break
                if not keep_going and not size_remaining:
                    # Close after the delay slot, handled on the next iteration
                    end_func = True
                    continue

            # Stop here if a size was specified and we have disassembled up to the size
            if hard_size > 0 and size_remaining == 0:
                end_func = True

            if i < len(insns) - 1 and self.parent.get_symbol(
                    insns[i + 1].address, local_only=True, type="func",
                    dead=False):
                end_func = True

            if end_func:
                # Close now unless the next insn is a nop (nops stay attached);
                # also close when only nops remain
                if self.is_nops(insns[i:]) or (
                        i < len(insns) - 1 and insns[i + 1].mnemonic != "nop"):
                    end_func = False
                    ret[func_addr] = func
                    func = []

        # Add the last function (or append trailing nops to the previous one).
        # Nothing is left if the loop just closed a function. With no previous
        # function (empty or all-nop input) the leftovers become their own entry
        # instead of raising StopIteration.
        if func:
            if ret and self.is_nops([i[0] for i in func]):
                next(reversed(ret.values())).extend(func)
            else:
                ret[func_addr] = func

        return ret
Exemplo n.º 9
0
class CommonSegCodeSubsegment(Segment):
    """Code subsegment that disassembles MIPS instructions with capstone.

    Turns the subsegment's ROM bytes into per-function lists of instruction
    tuples, resolves %hi/%lo symbol pairs and jump tables, and renders each
    function as assembly text with labels. Symbol/label/jump-table bookkeeping
    lives on ``self.parent``, which ``process_insns`` asserts is a
    ``CommonSegCode``.
    """

    # Mnemonic groups used to rank how a symbol is accessed (see update_access_mnemonic)
    double_mnemonics = ["ldc1", "sdc1"]
    word_mnemonics = ["addiu", "sw", "lw", "jtbl"]
    float_mnemonics = ["lwc1", "swc1"]
    short_mnemonics = ["addiu", "lh", "sh", "lhu"]
    byte_mnemonics = ["lb", "sb", "lbu"]

    # Chosen once, when the class body runs, from the global endianness option
    if options.get_endianess() == "big":
        capstone_mode = CS_MODE_MIPS64 | CS_MODE_BIG_ENDIAN
    else:
        capstone_mode = CS_MODE_MIPS32 | CS_MODE_LITTLE_ENDIAN

    md = Cs(CS_ARCH_MIPS, capstone_mode)
    md.detail = True
    md.skipdata = True

    @property
    def needs_symbols(self) -> bool:
        return True

    def get_linker_section(self) -> str:
        return ".text"

    @staticmethod
    def is_nops(insns):
        """Return True if every instruction is a ``nop`` (True for an empty sequence)."""
        return all(insn.mnemonic == "nop" for insn in insns)

    @staticmethod
    def is_branch_insn(mnemonic):
        """Return True for ``b*`` mnemonics (except ``binsl*`` and ``break``) and ``j``."""
        return (mnemonic.startswith("b") and not mnemonic.startswith("binsl")
                and not mnemonic == "break") or mnemonic == "j"

    def disassemble_code(self, rom_bytes, addsuffix=False):
        """Disassemble this subsegment and return its functions as assembly text.

        Runs capstone over ``rom_bytes[self.rom_start:self.rom_end]`` and then:

          * splits the result into functions (``process_insns``);
          * defines a function symbol for each;
          * resolves %hi/%lo references (``determine_symbols``);
          * gathers jump-table labels;
          * renders everything with ``add_labels``.

        Returns:
            The dict produced by ``add_labels``.
        """
        insns = [
            insn for insn in CommonSegCodeSubsegment.md.disasm(
                rom_bytes[self.rom_start:self.rom_end], self.vram_start)
        ]

        funcs = self.process_insns(insns, self.rom_start)

        # TODO: someday make func a subclass of symbol and store this disasm info there too
        for func in funcs:
            self.parent.get_symbol(func,
                                   type="func",
                                   create=True,
                                   define=True,
                                   local_only=True)

        funcs = self.determine_symbols(funcs)
        self.gather_jumptable_labels(rom_bytes)
        return self.add_labels(funcs, addsuffix)

    def process_insns(self, insns, rom_addr):
        """Group a flat list of capstone instructions into functions.

        Each instruction becomes a ``(insn, mnemonic, op_str, rom_addr)`` tuple:

          * ``move`` is expanded back to its real ``or``/``daddu``/``addu`` form;
          * ``jal`` operands and branch targets are replaced with symbol or
            label names;
          * ``mtc0``/``mfc0``/``mtc2``/``mfc2`` get a numeric ``rd`` register.

        ``rom_addr`` is the ROM offset of the first instruction and advances by 4
        per instruction.

        A function is closed in any of these cases:

          * after the delay slot of a ``jr`` that no branch jumps across;
          * once the size recorded on the function's symbol is reached;
          * right before the address of another known function.

        ``nop``s that follow stay attached to the closed function.

        Returns:
            OrderedDict mapping each function's start vram to its list of
            instruction tuples (empty for empty input).
        """
        assert (isinstance(self.parent, CommonSegCode))
        self.parent: CommonSegCode = self.parent

        ret = OrderedDict()

        func_addr = None
        func = []
        end_func = False
        labels = []

        big_endian = options.get_endianess() == "big"

        # Collect (branch address, branch target) pairs so a `jr` that some
        # branch jumps across is not mistaken for the end of the function
        for insn in insns:
            if self.is_branch_insn(insn.mnemonic):
                op_str_split = insn.op_str.split(" ")
                branch_target = op_str_split[-1]
                branch_addr = int(branch_target, 0)
                labels.append((insn.address, branch_addr))

        # Main loop
        for i, insn in enumerate(insns):
            mnemonic = insn.mnemonic
            op_str = insn.op_str
            func_addr = insn.address if len(func) == 0 else func[0][0].address

            # If this is non-zero, disasm exactly this many insns for the function.
            # Floor division keeps it an int so size_remaining can reach 0; a
            # float from `/` never equals 0 when the size isn't a multiple of 4.
            hard_size = 0
            func_sym = self.parent.get_symbol(func_addr, type="func")
            if func_sym and func_sym.size > 4:
                hard_size = func_sym.size // 4

            if mnemonic == "move":
                # Recover the real instruction from the funct field (low 6 bits)
                idx = 3 if big_endian else 0
                opcode = insn.bytes[idx] & 0b00111111

                op_str += ", $zero"

                if opcode == 37:
                    mnemonic = "or"
                elif opcode == 45:
                    mnemonic = "daddu"
                elif opcode == 33:
                    mnemonic = "addu"
                else:
                    print("INVALID INSTRUCTION " + str(insn), opcode)
            elif mnemonic == "jal":
                jal_addr = int(op_str, 0)
                jump_func = self.parent.get_symbol(jal_addr,
                                                   type="func",
                                                   create=True,
                                                   reference=True)
                op_str = jump_func.name
            elif self.is_branch_insn(insn.mnemonic):
                op_str_split = op_str.split(" ")
                branch_target = op_str_split[-1]
                branch_target_int = int(branch_target, 0)

                label = self.parent.get_symbol(branch_target_int,
                                               type="label",
                                               reference=True,
                                               local_only=True)

                if label:
                    label_name = label.name
                else:
                    # No label symbol yet: request a local `.L` label at the target
                    self.parent.labels_to_add.add(branch_target_int)
                    label_name = f".L{branch_target[2:].upper()}"

                op_str = " ".join(op_str_split[:-1] + [label_name])
            elif mnemonic in ["mtc0", "mfc0", "mtc2", "mfc2"]:
                # Replace the register name with the raw rd field number
                idx = 2 if big_endian else 1
                rd = (insn.bytes[idx] & 0xF8) >> 3
                op_str = op_str.split(" ")[0] + " $" + str(rd)

            func.append((insn, mnemonic, op_str, rom_addr))
            rom_addr += 4

            size_remaining = hard_size - len(func) if hard_size > 0 else 0

            if mnemonic == "jr":
                # Record potential jtbl jumps
                if op_str != "$ra":
                    self.parent.jtbl_jumps[insn.address] = op_str

                keep_going = False
                for label in labels:
                    if (label[0] > insn.address and label[1] <= insn.address
                        ) or (label[0] <= insn.address
                              and label[1] > insn.address):
                        keep_going = True
                        break
                if not keep_going and not size_remaining:
                    # Close after the delay slot, handled on the next iteration
                    end_func = True
                    continue

            # Stop here if a size was specified and we have disassembled up to the size
            if hard_size > 0 and size_remaining == 0:
                end_func = True

            if i < len(insns) - 1 and self.parent.get_symbol(
                    insns[i + 1].address, local_only=True, type="func",
                    dead=False):
                end_func = True

            if end_func:
                # Close now unless the next insn is a nop (nops stay attached);
                # also close when only nops remain
                if self.is_nops(insns[i:]) or (
                        i < len(insns) - 1 and insns[i + 1].mnemonic != "nop"):
                    end_func = False
                    ret[func_addr] = func
                    func = []

        # Add the last function (or append trailing nops to the previous one).
        # Nothing is left if the loop just closed a function. With no previous
        # function (empty or all-nop input) the leftovers become their own entry
        # instead of raising StopIteration.
        if func:
            if ret and self.is_nops([i[0] for i in func]):
                next(reversed(ret.values())).extend(func)
            else:
                ret[func_addr] = func

        return ret

    def update_access_mnemonic(self, sym, mnemonic):
        """Update ``sym.access_mnemonic`` with a newly seen access ``mnemonic``.

        The rules, in order:

          * an unset value or ``addiu`` is always replaced;
          * a double access is never replaced;
          * a float access is upgraded only by a double access;
          * short and byte accesses are kept;
          * anything else is overwritten.
        """
        if not sym.access_mnemonic:
            sym.access_mnemonic = mnemonic
        elif sym.access_mnemonic == "addiu":
            sym.access_mnemonic = mnemonic
        elif sym.access_mnemonic in self.double_mnemonics:
            return
        elif sym.access_mnemonic in self.float_mnemonics and mnemonic in self.double_mnemonics:
            sym.access_mnemonic = mnemonic
        elif sym.access_mnemonic in self.short_mnemonics:
            return
        elif sym.access_mnemonic in self.byte_mnemonics:
            return
        else:
            sym.access_mnemonic = mnemonic

    # Determine symbols
    def determine_symbols(self, funcs):
        """Resolve ``lui``/low-half pairs into symbol references.

        For each ``lui`` with a hex immediate >= 0x8000, the following
        instructions (up to ``hi_lo_max_distance - 1`` of them, default 6 - 1)
        are searched for one that uses the same register. Supported users are
        ``addiu`` and the listed loads/stores; a new ``lui`` of that register
        stops the search. On a match:

          * the full address is computed;
          * it becomes either a jump-table symbol (when a nearby non-``$ra``
            ``jr`` uses the result) or a regular symbol, possibly with a
            ``+0x..`` offset;
          * ``"%hi(...)"`` and ``"%lo(...)"`` strings are appended to the two
            instruction tuples.

        Returns:
            dict mapping function vram -> (mutated) instruction tuple list.
        """
        hi_lo_max_distance = options.get("hi_lo_max_distance", 6)
        ret = {}

        for func_addr in funcs:
            func = funcs[func_addr]
            func_end_addr = func[-1][0].address + 4

            possible_jtbl_jumps = [(k, v)
                                   for k, v in self.parent.jtbl_jumps.items()
                                   if k >= func_addr and k < func_end_addr]
            possible_jtbl_jumps.sort(key=lambda x: x[0])

            for i in range(len(func)):
                insn = func[i][0]

                # Ensure the first item in the list is always ahead of where we're looking
                while len(possible_jtbl_jumps
                          ) > 0 and possible_jtbl_jumps[0][0] < insn.address:
                    del possible_jtbl_jumps[0]

                if insn.mnemonic == "lui":
                    op_split = insn.op_str.split(", ")
                    reg = op_split[0]

                    if not op_split[1].startswith("0x"):
                        continue

                    lui_val = int(op_split[1], 0)
                    if lui_val >= 0x8000:
                        for j in range(i + 1,
                                       min(i + hi_lo_max_distance, len(func))):
                            s_insn = func[j][0]

                            s_op_split = s_insn.op_str.split(", ")

                            if s_insn.mnemonic == "lui" and reg == s_op_split[
                                    0]:
                                break

                            if s_insn.mnemonic in ["addiu", "ori"]:
                                s_reg = s_op_split[-2]
                            else:
                                s_reg = s_op_split[-1][s_op_split[-1].
                                                       rfind("(") + 1:-1]

                            if reg == s_reg:
                                if s_insn.mnemonic not in [
                                        "addiu", "lw", "sw", "lh", "sh", "lhu",
                                        "lb", "sb", "lbu", "lwc1", "swc1",
                                        "ldc1", "sdc1"
                                ]:
                                    break

                                # Match!
                                reg_ext = ""

                                junk_search = re.search(
                                    r"[\(]", s_op_split[-1])
                                if junk_search is not None:
                                    if junk_search.start() == 0:
                                        break
                                    s_str = s_op_split[-1][:junk_search.start(
                                    )]
                                    reg_ext = s_op_split[-1][junk_search.start(
                                    ):]
                                else:
                                    s_str = s_op_split[-1]

                                symbol_addr = (lui_val * 0x10000) + int(
                                    s_str, 0)

                                sym = None
                                offset_str = ""

                                if symbol_addr > func_addr and symbol_addr < self.parent.vram_end and len(
                                        possible_jtbl_jumps
                                ) > 0 and func_end_addr - s_insn.address >= 0x30:
                                    for jump in possible_jtbl_jumps:
                                        if jump[1] == s_op_split[0]:
                                            # NOTE(review): this measures the distance to the
                                            # nearest pending jr, not to the matched `jump`;
                                            # confirm whether `jump[0]` was intended.
                                            dist_to_jump = possible_jtbl_jumps[
                                                0][0] - s_insn.address
                                            if dist_to_jump <= 16:
                                                sym = self.parent.get_symbol(
                                                    symbol_addr,
                                                    create=True,
                                                    reference=True,
                                                    type="jtbl",
                                                    local_only=True)

                                                self.parent.jumptables[
                                                    symbol_addr] = (
                                                        func_addr,
                                                        func_end_addr)
                                                break

                                if not sym:
                                    sym = self.parent.get_symbol(
                                        symbol_addr,
                                        create=True,
                                        offsets=True,
                                        reference=True)
                                    offset = symbol_addr - sym.vram_start
                                    if offset != 0:
                                        offset_str = f"+0x{offset:X}"

                                if self.parent:
                                    self.parent.check_rodata_sym(
                                        func_addr, sym)

                                self.update_access_mnemonic(
                                    sym, s_insn.mnemonic)

                                sym_label = sym.name + offset_str

                                func[i] += ("%hi({})".format(sym_label), )
                                func[j] += ("%lo({}){}".format(
                                    sym_label, reg_ext), )
                                break
            ret[func_addr] = func
        return ret

    def add_labels(self, funcs, addsuffix):
        """Render each function's instruction tuples as assembly text lines.

        Labels emitted:

          * a ``glabel`` for each function;
          * ``glabel L<vram>_<rom>`` for jump-table targets;
          * ``.L<vram>:`` for targets listed in ``labels_to_add`` (each removed
            once emitted).

        Each instruction gets a ``/* rom vram bytes */`` comment. A %hi/%lo
        string appended by ``determine_symbols`` replaces the last operand.
        ``endlabel`` closes each function when ``addsuffix`` is set. With the
        ``find_file_boundaries`` option, a function followed by extra nops that
        ends on a 16-byte aligned ROM offset is reported as a likely file split.

        Returns:
            dict mapping function vram -> (list of text lines, rom offset of the
            function's first instruction).
        """
        ret = {}
        # Key of the final function; computed once rather than per function
        last_func = list(funcs.keys())[-1] if funcs else None

        for func in funcs:
            func_text = []

            # Add function glabel
            rom_addr = funcs[func][0][3]
            sym = self.parent.get_symbol(func,
                                         type="func",
                                         create=True,
                                         define=True,
                                         local_only=True)
            func_text.append(f"glabel {sym.name}")

            indent_next = False

            mnemonic_ljust = options.get("mnemonic_ljust", 11)
            rom_addr_padding = options.get("rom_address_padding", None)

            for insn in funcs[func]:
                insn_addr = insn[0].address
                # Add a label if we need one
                if insn_addr in self.parent.jtbl_glabels_to_add:
                    func_text.append(f"glabel L{insn_addr:X}_{insn[3]:X}")
                elif insn_addr in self.parent.labels_to_add:
                    self.parent.labels_to_add.remove(insn_addr)
                    func_text.append(".L{:X}:".format(insn_addr))

                if rom_addr_padding:
                    rom_str = "{0:0{1}X}".format(insn[3], rom_addr_padding)
                else:
                    rom_str = "{:X}".format(insn[3])

                asm_comment = "/* {} {:X} {} */".format(
                    rom_str, insn_addr, insn[0].bytes.hex().upper())

                # A 5th tuple element is the %hi/%lo operand from determine_symbols
                if len(insn) > 4:
                    op_str = ", ".join(insn[2].split(", ")[:-1] + [insn[4]])
                else:
                    op_str = insn[2]

                if self.is_branch_insn(insn[0].mnemonic):
                    branch_addr = int(insn[0].op_str.split(",")[-1].strip(), 0)
                    if branch_addr in self.parent.jtbl_glabels_to_add:
                        label_str = f"L{branch_addr:X}_{self.ram_to_rom(branch_addr):X}"
                        op_str = ", ".join(insn[2].split(", ")[:-1] +
                                           [label_str])

                insn_text = insn[1]
                if indent_next:
                    indent_next = False
                    insn_text = " " + insn_text

                asm_insn_text = "  {}{}".format(
                    insn_text.ljust(mnemonic_ljust), op_str).rstrip()

                func_text.append(asm_comment + asm_insn_text)

                # Indent the delay-slot instruction after a branch or jump;
                # `break` also starts with "b" but has no delay slot.
                raw_mnemonic = insn[0].mnemonic
                if (raw_mnemonic.startswith("b") and raw_mnemonic != "break"
                        ) or raw_mnemonic.startswith("j"):
                    indent_next = True

            if addsuffix:
                func_text.append(f"endlabel {sym.name}")

            ret[func] = (func_text, rom_addr)

            if options.get("find_file_boundaries"):
                # If this is not the last function in the file
                if func != last_func:

                    # Find where the function returns
                    jr_pos: Optional[int] = None
                    for i, insn in enumerate(reversed(funcs[func])):
                        if insn[0].mnemonic == "jr" and insn[0].op_str == "$ra":
                            jr_pos = i
                            break

                    # If there is more than 1 nop after the return
                    if jr_pos is not None and jr_pos > 1 and self.is_nops(
                        [i[0] for i in funcs[func][-jr_pos + 1:]]):
                        new_file_addr = funcs[func][-1][3] + 4
                        if (new_file_addr % 16) == 0:
                            if not self.parent.reported_file_split:
                                self.parent.reported_file_split = True
                                print(
                                    f"Segment {self.name}, function at vram {func:X} ends with extra nops, indicating a likely file split."
                                )
                                print(
                                    "File split suggestions for this segment will follow in config yaml format:"
                                )
                            print(f"      - [0x{new_file_addr:X}, asm]")

        return ret

    def gather_jumptable_labels(self, rom_bytes):
        """Collect the code addresses referenced by the parent's jump tables.

        ``self.parent.jumptables`` maps a table's vram address to the
        ``(func_start, func_end)`` range recorded in ``determine_symbols``.
        For every table, 4-byte words are read from the table's ROM offset,
        using the configured endianness, while each lies in
        ``[func_start, func_end]``. Each such word is added to
        ``self.parent.jtbl_glabels_to_add``.

        Tables mapping to a non-positive ROM offset are skipped, and reading
        stops at the end of ``rom_bytes``.
        """
        # TODO: use the seg_symbols for this
        # jumptables = [j.type == "jtbl" for j in self.seg_symbols]
        # Honour the configured endianness like the rest of this class
        endian = options.get_endianess()
        for jumptable, (start, end) in self.parent.jumptables.items():
            rom_offset = self.rom_start + jumptable - self.vram_start

            # Skip just this table; stopping here would drop every table after it
            if rom_offset <= 0:
                continue

            # Read entries until one leaves the function's range or the ROM ends
            while rom_offset + 4 <= len(rom_bytes):
                word_int = int.from_bytes(
                    rom_bytes[rom_offset:rom_offset + 4], endian)
                if not start <= word_int <= end:
                    break
                self.parent.jtbl_glabels_to_add.add(word_int)
                rom_offset += 4

    def should_scan(self) -> bool:
        """Scan only in "code" mode and when both ROM bounds are concrete."""
        return options.mode_active(
            "code") and self.rom_start != "auto" and self.rom_end != "auto"

    def should_split(self) -> bool:
        """Split only when extraction is enabled and "code" mode is active."""
        return self.extract and options.mode_active("code")
Exemplo n.º 10
0
    def disassemble_symbol(self, sym_bytes, sym_type):
        """Render ``sym_bytes`` as a single assembler data directive.

        Element width by ``sym_type``:

          * ``"double"``: 8 bytes;
          * ``"short"``: 2 bytes;
          * ``"byte"``: 1 byte;
          * anything else: 4 bytes.

        A trailing partial element is read using the remaining byte count.

        Type-specific handling:

          * ``"jtbl"`` is emitted as ``.word``. Non-zero entries become
            ``L<vram>_<rom>`` labels when they map to a ROM offset.
          * ``"ascii"`` is decoded as EUC-JP and escaped; data that fails to
            decode is rendered as ``.word`` instead.
          * For other 4-byte types, values >= 0x80000000 become a symbol name
            when one is known.
          * ``"float"``/``"double"`` values whose formatted text contains
            ``e-`` or ``nan`` make the whole symbol fall back to ``.word``.

        Returns:
            The directive text, e.g. ``".word 0x00000001, 0x00000002"``.
        """
        endian = options.get_endianess()
        if sym_type == "jtbl":
            sym_str = ".word "
        else:
            sym_str = f".{sym_type} "

        if sym_type == "double":
            slen = 8
        elif sym_type == "short":
            slen = 2
        elif sym_type == "byte":
            slen = 1
        else:
            slen = 4

        if sym_type == "ascii":
            # Only the decode can fail; catching narrowly keeps real errors
            # (and KeyboardInterrupt) from being silently turned into .word output
            try:
                ascii_str = sym_bytes.decode("EUC-JP")
            except UnicodeDecodeError:
                return self.disassemble_symbol(sym_bytes, "word")

            # ascii_str = ascii_str.rstrip("\x00")
            ascii_str = ascii_str.replace("\\", "\\\\")  # escape back slashes
            ascii_str = ascii_str.replace('"', '\\"')  # escape quotes
            ascii_str = ascii_str.replace("\x00", "\\0")
            ascii_str = ascii_str.replace("\n", "\\n")

            sym_str += f'"{ascii_str}"'
            return sym_str

        i = 0
        while i < len(sym_bytes):
            adv_amt = min(slen, len(sym_bytes) - i)
            bits = int.from_bytes(sym_bytes[i : i + adv_amt], endian)

            if sym_type == "jtbl":
                if bits == 0:
                    byte_str = "0"
                else:
                    rom_addr = self.get_most_parent().ram_to_rom(bits)

                    # Compare with None: ROM offset 0 is still a valid offset
                    if rom_addr is not None:
                        byte_str = f"L{bits:X}_{rom_addr:X}"
                    else:
                        byte_str = f"0x{bits:X}"
            elif slen == 4 and bits >= 0x80000000:
                sym = self.get_most_parent().get_symbol(bits, reference=True)
                if sym:
                    byte_str = sym.name
                else:
                    byte_str = "0x{0:0{1}X}".format(bits, 2 * slen)
            else:
                byte_str = "0x{0:0{1}X}".format(bits, 2 * slen)

            if sym_type in ["float", "double"]:
                if sym_type == "float":
                    float_str = floats.format_f32_imm(bits)
                else:
                    float_str = floats.format_f64_imm(bits)

                # Fall back to .word if we see weird float values
                # TODO: cut the symbol in half maybe where we see the first nan or something
                if "e-" in float_str or "nan" in float_str:
                    return self.disassemble_symbol(sym_bytes, "word")
                else:
                    byte_str = float_str

            sym_str += byte_str

            i += adv_amt

            if i < len(sym_bytes):
                sym_str += ", "

        return sym_str
Exemplo n.º 11
0
def configure_disassembler():
    """Copy splat's options into the global spimdisasm and rabbitizer configs.

    All settings are written onto module-level configuration objects
    (``spimdisasm.common.GlobalConfig`` and ``rabbitizer.config``), so they
    apply to every disassembly performed afterwards in this process.
    """
    gcfg = spimdisasm.common.GlobalConfig
    rcfg = rabbitizer.config

    # Symbol / label behaviour
    gcfg.PRODUCE_SYMBOLS_PLUS_OFFSET = True
    gcfg.TRUST_USER_FUNCTIONS = True
    gcfg.TRUST_JAL_FUNCTIONS = True
    gcfg.GLABEL_ASM_COUNT = False

    gcfg.ASM_COMMENT_OFFSET_WIDTH = 6 if options.rom_address_padding() else 0

    # spimdisasm does not analyze non-text sections, so names derived from
    # section or data type would add nothing.
    gcfg.AUTOGENERATED_NAMES_BASED_ON_SECTION_TYPE = False
    gcfg.AUTOGENERATED_NAMES_BASED_ON_DATA_TYPE = False

    gcfg.SYMBOL_FINDER_FILTERED_ADDRESSES_AS_HILO = False

    # Register naming / instruction formatting
    rcfg.regNames_userFpcCsr = False
    rcfg.regNames_vr4300Cop0NamedRegisters = False

    rcfg.misc_opcodeLJust = options.mnemonic_ljust() - 1

    gpr_abi = options.get_mips_abi_gpr()
    rcfg.regNames_gprAbiNames = rabbitizer.Abi.fromStr(gpr_abi)
    fpr_abi = options.get_mips_abi_float_regs()
    rcfg.regNames_fprAbiNames = rabbitizer.Abi.fromStr(fpr_abi)

    input_endian = spimdisasm.common.InputEndian
    is_big = options.get_endianess() == "big"
    gcfg.ENDIAN = input_endian.BIG if is_big else input_endian.LITTLE

    rcfg.pseudos_pseudoMove = False

    # Compiler-specific tweaks
    chosen = options.get_compiler()
    if chosen == compiler.SN64:
        rcfg.regNames_namedRegisters = False
        rcfg.toolchainTweaks_sn64DivFix = True
        rcfg.toolchainTweaks_treatJAsUnconditionalBranch = True
        gcfg.ASM_COMMENT = False
        gcfg.SYMBOL_FINDER_FILTERED_ADDRESSES_AS_HILO = False
        gcfg.COMPILER = spimdisasm.common.Compiler.SN64
    elif chosen == compiler.GCC:
        rcfg.toolchainTweaks_treatJAsUnconditionalBranch = True
        gcfg.COMPILER = spimdisasm.common.Compiler.GCC
    elif chosen == compiler.IDO:
        gcfg.COMPILER = spimdisasm.common.Compiler.IDO

    gcfg.GP_VALUE = options.get_gp()

    # Label macros
    gcfg.ASM_TEXT_LABEL = options.get_asm_function_macro()
    gcfg.ASM_DATA_LABEL = options.get_asm_data_macro()
    gcfg.ASM_TEXT_END_LABEL = options.get_asm_end_label()

    # A plain ".globl" function macro gets an explicit .ent and a label line.
    if gcfg.ASM_TEXT_LABEL == ".globl":
        gcfg.ASM_TEXT_ENT_LABEL = ".ent"
        gcfg.ASM_TEXT_FUNC_AS_LABEL = True

    gcfg.LINE_ENDS = options.c_newline()

    if options.get_platform() == "n64":
        symbols.spim_context.fillDefaultBannedSymbols()
Exemplo n.º 12
0
class CommonSegCodeSubsegment(Segment):
    """Code subsegment that disassembles MIPS machine code with capstone.

    ``scan_code`` disassembles the subsegment's ROM range, groups the
    instructions into functions and resolves %hi/%lo and %gp_rel symbol
    references. ``split_code`` then renders each function as assembly text.
    """

    # Mnemonic groups consulted by update_access_mnemonic when deciding which
    # access mnemonic to keep for a symbol (word_mnemonics is not referenced
    # in this class).
    double_mnemonics = ["ldc1", "sdc1"]
    word_mnemonics = ["addiu", "sw", "lw", "jtbl"]
    float_mnemonics = ["lwc1", "swc1"]
    short_mnemonics = ["addiu", "lh", "sh", "lhu"]
    byte_mnemonics = ["lb", "sb", "lbu"]
    # ABI register name -> numeric register name, applied by replace_reg_names
    # for SN64 output. "$sp" maps to itself, so it is left unchanged.
    reg_numbers = {
        "$zero": "$0",
        "$at": "$1",
        "$v0": "$2",
        "$v1": "$3",
        "$a0": "$4",
        "$a1": "$5",
        "$a2": "$6",
        "$a3": "$7",
        "$t0": "$8",
        "$t1": "$9",
        "$t2": "$10",
        "$t3": "$11",
        "$t4": "$12",
        "$t5": "$13",
        "$t6": "$14",
        "$t7": "$15",
        "$s0": "$16",
        "$s1": "$17",
        "$s2": "$18",
        "$s3": "$19",
        "$s4": "$20",
        "$s5": "$21",
        "$s6": "$22",
        "$s7": "$23",
        "$t8": "$24",
        "$t9": "$25",
        "$k0": "$26",
        "$k1": "$27",
        "$gp": "$28",
        "$sp": "$sp",
        "$fp": "$30",
        "$ra": "$31",
    }

    # Capstone disassembler shared by every instance. It is configured from
    # the project's endianness when the class body is executed.
    if options.get_endianess() == "big":
        capstone_mode = CS_MODE_MIPS64 | CS_MODE_BIG_ENDIAN
    else:
        capstone_mode = CS_MODE_MIPS32 | CS_MODE_LITTLE_ENDIAN

    md = Cs(CS_ARCH_MIPS, capstone_mode)
    md.detail = False
    md.skipdata = True

    @property
    def needs_symbols(self) -> bool:
        return True

    def get_linker_section(self) -> str:
        return ".text"

    @staticmethod
    def is_nops(insns: List[CsInsn]) -> bool:
        """Return True if every instruction is a ``nop`` (True for an empty list)."""
        return all(insn.mnemonic == "nop" for insn in insns)

    @staticmethod
    def is_branch_insn(mnemonic):
        """Return True for branch mnemonics (``b*`` except ``binsl*``/``break``) and ``j``."""
        return (mnemonic.startswith("b") and not mnemonic.startswith("binsl")
                and not mnemonic == "break") or mnemonic == "j"

    @staticmethod
    def replace_reg_names(op_str):
        """Replace ABI register names in ``op_str`` with their numeric forms."""
        for regname, regnum in CommonSegCodeSubsegment.reg_numbers.items():
            op_str = op_str.replace(regname, regnum)
        return op_str

    def scan_code(self, rom_bytes, is_asm=False):
        """Disassemble this subsegment, split it into functions and find symbols.

        The functions are stored in ``self.funcs`` (vram address -> Symbol).
        """
        insns: List[CsInsn] = list(
            CommonSegCodeSubsegment.md.disasm(
                rom_bytes[self.rom_start:self.rom_end], self.vram_start))

        self.funcs: typing.OrderedDict[int, Symbol] = self.process_insns(
            insns, self.rom_start, is_asm=is_asm)

        # TODO: set these in creation
        for func in self.funcs.values():
            func.define = True
            func.local_only = True
            func.size = len(func.insns) * 4

        self.determine_symbols()

    def split_code(self, rom_bytes):
        """Collect jump-table labels, then render every function as asm text."""
        self.gather_jumptable_labels(rom_bytes)
        return self.add_labels()

    def process_insns(self,
                      insns: List[CsInsn],
                      rom_addr,
                      is_asm=False) -> typing.OrderedDict[int, Symbol]:
        """Group disassembled instructions into functions.

        A function is considered finished after a ``jr`` that no branch jumps
        across (and no size is left to consume), when a known size (greater
        than 4 bytes) has been reached, or when the next address already has
        a live function symbol. It is only closed once everything that
        remains is nops or the next instruction is not a nop. Trailing nops
        at the end are appended to the last real function.

        Side effects: creates symbols on the parent, and records branch
        targets in ``self.parent.labels_to_add`` and non-``$ra`` ``jr``
        targets in ``self.parent.jtbl_jumps``.

        Args:
            insns: capstone instructions in address order.
            rom_addr: ROM offset of the first instruction.
            is_asm: when True, the SN64 ``break``/``div`` rewrites are skipped.

        Returns:
            Functions keyed by their vram start address.
        """
        assert isinstance(self.parent, CommonSegCode)
        self.parent: CommonSegCode = self.parent

        ret: typing.OrderedDict[int, Symbol] = OrderedDict()

        # Nothing to split. Without this guard the post-loop code would
        # reference an unbound ``func``.
        if not insns:
            return ret

        end_func = False
        start_new_func = True
        labels = []

        big_endian = options.get_endianess() == "big"

        # Collect (branch address, branch target) pairs so a jr can tell
        # whether some branch jumps across it.
        for insn in insns:
            if self.is_branch_insn(insn.mnemonic):
                op_str_split = insn.op_str.split(" ")
                branch_target = op_str_split[-1]
                branch_addr = int(branch_target, 0)
                labels.append((insn.address, branch_addr))

        # Main loop
        for i, insn in enumerate(insns):
            mnemonic = insn.mnemonic
            op_str = insn.op_str

            # If this is non-zero, disassemble exactly this many instructions
            hard_size = 0

            if start_new_func:
                func: Symbol = self.parent.create_symbol(insn.address,
                                                         type="func")
                start_new_func = False

            if func.size > 4:
                # Instruction count implied by the function's byte size
                hard_size = func.size // 4

            if options.get_compiler() == SN64:
                op_str = self.replace_reg_names(op_str)

            if mnemonic == "move":
                # Let's get the actual instruction out
                idx = 3 if big_endian else 0
                opcode = insn.bytes[idx] & 0b00111111

                if options.get_compiler() == SN64:
                    op_str += ", $0"
                else:
                    op_str += ", $zero"

                if opcode == 37:
                    mnemonic = "or"
                elif opcode == 45:
                    mnemonic = "daddu"
                elif opcode == 33:
                    mnemonic = "addu"
                else:
                    print("INVALID INSTRUCTION " + str(insn), opcode)
            elif mnemonic == "jal":
                jal_addr = int(op_str, 0)
                jump_func = self.parent.create_symbol(jal_addr,
                                                      type="func",
                                                      reference=True)
                op_str = jump_func.name
            elif self.is_branch_insn(insn.mnemonic):
                op_str_split = op_str.split(" ")
                branch_target = op_str_split[-1]
                branch_target_int = int(branch_target, 0)

                label_sym = self.parent.get_symbol(branch_target_int,
                                                   type="label",
                                                   reference=True,
                                                   local_only=True)

                if label_sym:
                    label_name = label_sym.name
                else:
                    self.parent.labels_to_add.add(branch_target_int)
                    label_name = f".L{branch_target[2:].upper()}"

                op_str = " ".join(op_str_split[:-1] + [label_name])
            elif mnemonic in ["mtc0", "mfc0", "mtc2", "mfc2"]:
                idx = 2 if big_endian else 1
                rd = (insn.bytes[idx] & 0xF8) >> 3
                op_str = op_str.split(" ")[0] + " $" + str(rd)
            elif (mnemonic == "break" and op_str in ["6", "7"]
                  and options.get_compiler() == SN64 and not is_asm):
                # SN64's assembler expands div to have break if dividing by zero
                # However, the break it generates is different than the one it generates with `break N`
                # So we replace break instructions for SN64 with the exact word that the assembler generates when expanding div
                if op_str == "6":
                    mnemonic = ".word 0x0006000D"
                    op_str = ""
                elif op_str == "7":
                    mnemonic = ".word 0x0007000D"
                    op_str = ""
            elif (mnemonic in ["div", "divu"]
                  and options.get_compiler() == SN64 and not is_asm):
                # SN64's assembler also doesn't like assembling `div $0, a, b` with .set noat active
                # Removing the $0 fixes this issue
                if op_str.startswith("$0, "):
                    op_str = op_str[4:]

            func.insns.append(Instruction(insn, mnemonic, op_str, rom_addr))
            rom_addr += 4

            size_remaining = hard_size - len(
                func.insns) if hard_size > 0 else 0

            if mnemonic == "jr":
                # Record potential jtbl jumps
                if op_str not in ["$ra", "$31"]:
                    self.parent.jtbl_jumps[insn.address] = op_str

                # A branch that crosses this jr means the function continues
                keep_going = any(
                    (src > insn.address and dst <= insn.address)
                    or (src <= insn.address and dst > insn.address)
                    for src, dst in labels)
                if not keep_going and not size_remaining:
                    end_func = True
                    continue

            # Stop here if a size was specified and we have disassembled up to the size
            if hard_size > 0 and size_remaining == 0:
                end_func = True

            if i < len(insns) - 1 and self.parent.get_symbol(
                    insns[i + 1].address, local_only=True, type="func",
                    dead=False):
                end_func = True

            if end_func:
                # Keep padding nops with the current function unless the
                # rest of the subsegment is nothing but nops.
                if (self.is_nops(insns[i:]) or i < len(insns) - 1
                        and insns[i + 1].mnemonic != "nop"):
                    end_func = False
                    start_new_func = True
                    ret[func.vram_start] = func

        # Add the last function, or append trailing nops to the previous one.
        # If nothing has been committed yet (e.g. an all-nop subsegment), the
        # current function is kept as-is instead of raising StopIteration.
        if ret and self.is_nops([insn.instruction for insn in func.insns]):
            prev_func = next(reversed(ret.values()))
            # ``func`` may already be the last entry. Extending it with its
            # own instructions would duplicate them.
            if prev_func is not func:
                prev_func.insns.extend(func.insns)
        else:
            ret[func.vram_start] = func

        return ret

    def update_access_mnemonic(self, sym, mnemonic):
        """Record the mnemonic used to access ``sym``, keeping the most specific.

        An unset or ``addiu`` access is always replaced. A double access is
        never replaced. A float access is upgraded only to a double access.
        Short and byte accesses are kept. Anything else is replaced.
        """
        if not sym.access_mnemonic:
            sym.access_mnemonic = mnemonic
        elif sym.access_mnemonic == "addiu":
            sym.access_mnemonic = mnemonic
        elif sym.access_mnemonic in self.double_mnemonics:
            return
        elif (sym.access_mnemonic in self.float_mnemonics
              and mnemonic in self.double_mnemonics):
            sym.access_mnemonic = mnemonic
        elif sym.access_mnemonic in self.short_mnemonics:
            return
        elif sym.access_mnemonic in self.byte_mnemonics:
            return
        else:
            sym.access_mnemonic = mnemonic

    # Determine symbols
    def determine_symbols(self):
        """Attach symbol references to %gp_rel accesses and lui/lo pairs.

        For every function in ``self.funcs``:

        - instructions addressing ``($gp)`` get a %gp_rel symbol at
          ``gp + offset``;
        - each ``lui`` with an immediate >= 0x8000 is paired with a later
          instruction (within ``hi_lo_max_distance``) that uses the same
          register, producing %hi/%lo references;
        - a symbol loaded into the register of a nearby non-``$ra`` ``jr``
          is created as a jump table and recorded in
          ``self.parent.jumptables`` as ``(func start, func end)``.
        """
        hi_lo_max_distance = options.hi_lo_max_distance()

        for func_addr, func in self.funcs.items():
            func_end_addr = func.insns[-1].instruction.address + 4

            possible_jtbl_jumps = [(k, v)
                                   for k, v in self.parent.jtbl_jumps.items()
                                   if func_addr <= k < func_end_addr]
            possible_jtbl_jumps.sort(key=lambda x: x[0])

            for i in range(len(func.insns)):
                hi_insn: CsInsn = func.insns[i].instruction

                # Ensure the first item in the list is always ahead of where we're looking
                while (len(possible_jtbl_jumps) > 0
                       and possible_jtbl_jumps[0][0] < hi_insn.address):
                    del possible_jtbl_jumps[0]

                # Find gp relative reads and writes e.g  lw $a1, 0x670($gp)
                if hi_insn.op_str.endswith("($gp)"):
                    gp_base = options.get_gp()
                    if gp_base is None:
                        log.error(
                            "gp_value not set in yaml, can't calculate %gp_rel reloc value for "
                            + hi_insn.op_str)

                    op_split = hi_insn.op_str.split(", ")
                    gp_offset = op_split[1][:-5]  # extract the 0x670 part
                    if len(gp_offset) == 0:
                        gp_offset = 0
                    else:
                        gp_offset = int(gp_offset, 16)
                    symbol_addr = gp_base + gp_offset

                    sym = self.parent.create_symbol(symbol_addr,
                                                    offsets=True,
                                                    reference=True)
                    offset = symbol_addr - sym.vram_start
                    offset_str = f"+0x{offset:X}"

                    if self.parent:
                        self.parent.check_rodata_sym(func_addr, sym)

                    self.update_access_mnemonic(sym, hi_insn.mnemonic)

                    func.insns[i].is_gp = True
                    func.insns[i].hi_lo_sym = sym
                    func.insns[i].sym_offset_str = offset_str
                # All hi/lo pairs start with a lui
                elif hi_insn.mnemonic == "lui":
                    op_split = hi_insn.op_str.split(", ")
                    hi_reg = op_split[0]

                    if not op_split[1].startswith("0x"):
                        continue

                    lui_val = int(op_split[1], 0)

                    # Assumes all luis are going to load from 0x80000000 or higher (maybe false)
                    if lui_val >= 0x8000:
                        # Iterate over the next few instructions to see if we can find a matching lo
                        for j in range(
                                i + 1,
                                min(i + hi_lo_max_distance, len(func.insns))):
                            lo_insn = func.insns[j].instruction

                            s_op_split = lo_insn.op_str.split(", ")

                            # The hi register is overwritten: no pair here
                            if lo_insn.mnemonic == "lui" and hi_reg == s_op_split[
                                    0]:
                                break

                            if lo_insn.mnemonic in ["addiu", "ori"]:
                                lo_reg = s_op_split[-2]
                            else:
                                lo_reg = s_op_split[-1][s_op_split[-1].
                                                        rfind("(") + 1:-1]

                            if hi_reg == lo_reg:
                                if lo_insn.mnemonic not in [
                                        "addiu",
                                        "lw",
                                        "sw",
                                        "lh",
                                        "sh",
                                        "lhu",
                                        "lb",
                                        "sb",
                                        "lbu",
                                        "lwc1",
                                        "swc1",
                                        "ldc1",
                                        "sdc1",
                                ]:
                                    break

                                # Match!
                                reg_ext = ""

                                # Split an operand like "0x1234($at)" into the
                                # immediate ("0x1234") and the base-register
                                # suffix ("($at)"). An operand that starts
                                # with "(" has no immediate, so give up.
                                junk_search = re.search(
                                    r"[\(]", s_op_split[-1])
                                if junk_search is not None:
                                    if junk_search.start() == 0:
                                        break
                                    s_str = s_op_split[-1][:junk_search.start(
                                    )]
                                    reg_ext = s_op_split[-1][junk_search.start(
                                    ):]
                                else:
                                    s_str = s_op_split[-1]

                                if options.get_compiler() == SN64:
                                    reg_ext = CommonSegCodeSubsegment.replace_reg_names(
                                        reg_ext)

                                symbol_addr = (lui_val * 0x10000) + int(
                                    s_str, 0)

                                sym: Optional[Symbol] = None
                                offset_str = ""

                                # If the symbol is likely in the rodata section
                                if (not self.parent.text_follows_rodata
                                        and symbol_addr > func_addr) or (
                                            self.parent.text_follows_rodata
                                            and symbol_addr < func_addr):
                                    # Sanity check that the symbol is within this segment's vram
                                    if (self.parent.vram_end and symbol_addr <
                                            self.parent.vram_end):
                                        # If we've seen possible jumps to a jumptable and this symbol isn't too close to the end of the function
                                        if (len(possible_jtbl_jumps) > 0
                                                and func_end_addr -
                                                lo_insn.address >= 0x30):
                                            for jump in possible_jtbl_jumps:
                                                if jump[1] == s_op_split[0]:
                                                    # Distance to the jr that uses this
                                                    # register, not merely the nearest jr.
                                                    dist_to_jump = (
                                                        jump[0] -
                                                        lo_insn.address)
                                                    if dist_to_jump <= 16:
                                                        sym = self.parent.create_symbol(
                                                            symbol_addr,
                                                            reference=True,
                                                            type="jtbl",
                                                            local_only=True,
                                                        )

                                                        self.parent.jumptables[
                                                            symbol_addr] = (
                                                                func_addr,
                                                                func_end_addr)
                                                        break

                                if not sym:
                                    sym = self.parent.create_symbol(
                                        symbol_addr,
                                        offsets=True,
                                        reference=True)
                                    offset = symbol_addr - sym.vram_start
                                    if offset != 0:
                                        offset_str = f"+0x{offset:X}"

                                if self.parent:
                                    self.parent.check_rodata_sym(
                                        func_addr, sym)

                                self.update_access_mnemonic(
                                    sym, lo_insn.mnemonic)

                                func.insns[i].is_hi = True
                                func.insns[i].hi_lo_sym = sym
                                func.insns[i].sym_offset_str = offset_str

                                func.insns[j].is_lo = True
                                func.insns[j].hi_lo_sym = sym
                                func.insns[j].sym_offset_str = offset_str
                                func.insns[j].hi_lo_reg = reg_ext
                                break

    def add_labels(self):
        """Render every function in ``self.funcs`` as a list of assembly lines.

        Adds jump-table global labels and local branch labels, rewrites
        %hi/%lo/%gp_rel operands, and indents the instruction that follows a
        branch or jump. When ``find_file_boundaries`` is enabled, it prints
        file-split suggestions for functions followed by extra nops that end
        on a 16-byte boundary.

        Returns:
            dict mapping function vram address to ``(lines, func.rom)``.
        """
        ret = {}

        function_macro = options.get_asm_function_macro()
        data_macro = options.get_asm_data_macro()
        mnemonic_ljust = options.mnemonic_ljust()
        rom_addr_padding = options.rom_address_padding()
        end_label = options.get_asm_end_label()
        is_sn64 = options.get_compiler() == SN64
        # Computed once instead of rebuilding the key list for every function
        last_func_addr = next(reversed(self.funcs), None)

        for func_addr, func in self.funcs.items():
            func_text = []

            # Add function label
            func_text.append(f"{function_macro} {func.name}")

            if is_sn64:
                func_text.append(f".ent {func.name}")
                func_text.append(f"{func.name}:")

            indent_next = False

            for insn in func.insns:
                insn_addr = insn.instruction.address
                # Add a label if we need one
                if insn_addr in self.parent.jtbl_glabels_to_add:
                    func_text.append(
                        f"{data_macro} L{insn_addr:X}_{insn.rom_addr:X}")
                elif insn_addr in self.parent.labels_to_add:
                    self.parent.labels_to_add.remove(insn_addr)
                    func_text.append(".L{:X}:".format(insn_addr))

                if rom_addr_padding:
                    rom_str = "{0:0{1}X}".format(insn.rom_addr,
                                                 rom_addr_padding)
                else:
                    rom_str = "{:X}".format(insn.rom_addr)

                if is_sn64:
                    asm_comment = ""
                else:
                    asm_comment = "/* {} {:X} {} */".format(
                        rom_str, insn_addr,
                        insn.instruction.bytes.hex().upper())

                if insn.is_hi:
                    assert insn.hi_lo_sym
                    op_str = ", ".join(
                        insn.op_str.split(", ")[:-1] +
                        [f"%hi({insn.hi_lo_sym.name}{insn.sym_offset_str})"])
                elif insn.is_lo:
                    assert insn.hi_lo_sym
                    op_str = ", ".join(
                        insn.op_str.split(", ")[:-1] + [
                            f"%lo({insn.hi_lo_sym.name}{insn.sym_offset_str}){insn.hi_lo_reg}"
                        ])
                elif insn.is_gp:
                    op_str = ", ".join(
                        insn.op_str.split(", ")[:-1] + [
                            f"%gp_rel({insn.hi_lo_sym.name}{insn.sym_offset_str})($gp)"
                        ])
                else:
                    op_str = insn.op_str

                if self.is_branch_insn(insn.instruction.mnemonic):
                    branch_addr = int(
                        insn.instruction.op_str.split(",")[-1].strip(), 0)
                    if branch_addr in self.parent.jtbl_glabels_to_add:
                        label_str = f"L{branch_addr:X}_{self.ram_to_rom(branch_addr):X}"
                        op_str = ", ".join(
                            insn.op_str.split(", ")[:-1] + [label_str])

                insn_text = insn.mnemonic
                if indent_next:
                    indent_next = False
                    insn_text = " " + insn_text

                asm_insn_text = "  {}{}".format(
                    insn_text.ljust(mnemonic_ljust), op_str).rstrip()

                func_text.append(asm_comment + asm_insn_text)

                # Indent the instruction after a branch or jump (its delay
                # slot). "break" and "binsl*" are not branches.
                if (self.is_branch_insn(insn.instruction.mnemonic)
                        or insn.instruction.mnemonic.startswith("j")):
                    indent_next = True

            if end_label:
                func_text.append(f"{end_label} {func.name}")

            ret[func_addr] = (func_text, func.rom)

            # If this is not the last function in the file
            if options.find_file_boundaries() and func_addr != last_func_addr:
                # Find where the function returns
                jr_pos: Optional[int] = None
                for i, insn in enumerate(reversed(func.insns)):
                    if (insn.instruction.mnemonic == "jr"
                            and insn.instruction.op_str in ["$ra", "$31"]):
                        jr_pos = i
                        break

                # If there is more than 1 nop after the return
                if (jr_pos is not None and jr_pos > 1 and self.is_nops([
                        insn.instruction
                        for insn in func.insns[-jr_pos + 1:]
                ])):
                    new_file_addr = func.insns[-1].rom_addr + 4
                    if (new_file_addr % 16) == 0:
                        if not self.parent.reported_file_split:
                            self.parent.reported_file_split = True
                            print(
                                f"Segment {self.name}, function at vram {func_addr:X} ends with extra nops, indicating a likely file split."
                            )
                            print(
                                "File split suggestions for this segment will follow in config yaml format:"
                            )
                        print(f"      - [0x{new_file_addr:X}, asm]")

        return ret

    def gather_jumptable_labels(self, rom_bytes):
        """Mark jump-table targets that need global labels.

        For each table in ``self.parent.jumptables``, consecutive words are
        read from the table's ROM location for as long as they lie within
        the owning function's ``[start, end]`` vram range. Each such word is
        added to ``self.parent.jtbl_glabels_to_add``.
        """
        # TODO: use the seg_symbols for this
        # jumptables = [j.type == "jtbl" for j in self.seg_symbols]
        endian = options.get_endianess()
        for jumptable, (start, end) in self.parent.jumptables.items():
            rom_offset = self.rom_start + jumptable - self.vram_start

            # This table can't be read from here. Skip it, but keep
            # processing the remaining tables.
            if rom_offset <= 0:
                continue

            while True:
                word = rom_bytes[rom_offset:rom_offset + 4]
                # Ran off the end of the ROM. Without this check an empty
                # read decodes as 0 and could loop forever when 0 is in
                # [start, end].
                if len(word) < 4:
                    break
                word_int = int.from_bytes(word, endian)
                if start <= word_int <= end:
                    self.parent.jtbl_glabels_to_add.add(word_int)
                else:
                    break

                rom_offset += 4

    def should_scan(self) -> bool:
        return (options.mode_active("code") and self.rom_start != "auto"
                and self.rom_end != "auto")

    def should_split(self) -> bool:
        return self.extract and options.mode_active("code")
Exemplo n.º 13
0
    def process_insns(self,
                      insns: List[CsInsn],
                      rom_addr,
                      is_asm=False) -> typing.OrderedDict[int, Symbol]:
        assert isinstance(self.parent, CommonSegCode)
        self.parent: CommonSegCode = self.parent

        ret: typing.OrderedDict[int, Symbol] = OrderedDict()

        end_func = False
        start_new_func = True
        labels = []

        big_endian = options.get_endianess() == "big"

        # Collect labels
        for insn in insns:
            if self.is_branch_insn(insn.mnemonic):
                op_str_split = insn.op_str.split(" ")
                branch_target = op_str_split[-1]
                branch_addr = int(branch_target, 0)
                labels.append((insn.address, branch_addr))

        # Main loop
        for i, insn in enumerate(insns):
            mnemonic = insn.mnemonic
            op_str = insn.op_str

            # If this is non-zero, disasm size insns
            hard_size = 0

            if start_new_func:
                func: Symbol = self.parent.create_symbol(insn.address,
                                                         type="func")
                start_new_func = False

            if func.size > 4:
                hard_size = func.size / 4

            if options.get_compiler() == SN64:
                op_str = self.replace_reg_names(op_str)

            if mnemonic == "move":
                # Let's get the actual instruction out
                idx = 3 if big_endian else 0
                opcode = insn.bytes[idx] & 0b00111111

                if options.get_compiler() == SN64:
                    op_str += ", $0"
                else:
                    op_str += ", $zero"

                if opcode == 37:
                    mnemonic = "or"
                elif opcode == 45:
                    mnemonic = "daddu"
                elif opcode == 33:
                    mnemonic = "addu"
                else:
                    print("INVALID INSTRUCTION " + str(insn), opcode)
            elif mnemonic == "jal":
                jal_addr = int(op_str, 0)
                jump_func = self.parent.create_symbol(jal_addr,
                                                      type="func",
                                                      reference=True)
                op_str = jump_func.name
            elif self.is_branch_insn(insn.mnemonic):
                op_str_split = op_str.split(" ")
                branch_target = op_str_split[-1]
                branch_target_int = int(branch_target, 0)

                label_sym = self.parent.get_symbol(branch_target_int,
                                                   type="label",
                                                   reference=True,
                                                   local_only=True)

                if label_sym:
                    label_name = label_sym.name
                else:
                    self.parent.labels_to_add.add(branch_target_int)
                    label_name = f".L{branch_target[2:].upper()}"

                op_str = " ".join(op_str_split[:-1] + [label_name])
            elif mnemonic in ["mtc0", "mfc0", "mtc2", "mfc2"]:
                idx = 2 if big_endian else 1
                rd = (insn.bytes[idx] & 0xF8) >> 3
                op_str = op_str.split(" ")[0] + " $" + str(rd)
            elif (mnemonic == "break" and op_str in ["6", "7"]
                  and options.get_compiler() == SN64 and not is_asm):
                # SN64's assembler expands div to have break if dividing by zero
                # However, the break it generates is different than the one it generates with `break N`
                # So we replace break instrutions for SN64 with the exact word that the assembler generates when expanding div
                if op_str == "6":
                    mnemonic = ".word 0x0006000D"
                    op_str = ""
                elif op_str == "7":
                    mnemonic = ".word 0x0007000D"
                    op_str = ""
            elif (mnemonic in ["div", "divu"]
                  and options.get_compiler() == SN64 and not is_asm):
                # SN64's assembler also doesn't like assembling `div $0, a, b` with .set noat active
                # Removing the $0 fixes this issue
                if op_str.startswith("$0, "):
                    op_str = op_str[4:]

            func.insns.append(Instruction(insn, mnemonic, op_str, rom_addr))
            rom_addr += 4

            size_remaining = hard_size - len(
                func.insns) if hard_size > 0 else 0

            if mnemonic == "jr":
                # Record potential jtbl jumps
                if op_str not in ["$ra", "$31"]:
                    self.parent.jtbl_jumps[insn.address] = op_str

                keep_going = False
                for label in labels:
                    if (label[0] > insn.address and label[1] <= insn.address
                        ) or (label[0] <= insn.address
                              and label[1] > insn.address):
                        keep_going = True
                        break
                if not keep_going and not size_remaining:
                    end_func = True
                    continue

            # Stop here if a size was specified and we have disassembled up to the size
            if hard_size > 0 and size_remaining == 0:
                end_func = True

            if i < len(insns) - 1 and self.parent.get_symbol(
                    insns[i + 1].address, local_only=True, type="func",
                    dead=False):
                end_func = True

            if end_func:
                if (self.is_nops(insns[i:]) or i < len(insns) - 1
                        and insns[i + 1].mnemonic != "nop"):
                    end_func = False
                    start_new_func = True
                    ret[func.vram_start] = func

        # Add the last function (or append nops to the previous one)
        if not self.is_nops([insn.instruction for insn in func.insns]):
            ret[func.vram_start] = func
        else:
            next(reversed(ret.values())).insns.extend(func.insns)

        return ret