def set_added_segment_headers(self, nsegments):
        """Write the program headers describing the added code (and data) segment.

        The code header is stored at ``self.original_header_end``.  When
        ``self.data_fallback`` is set, a data header is additionally stored
        32 bytes (the packed size of one header) after it.  Finally the 16-bit
        segment count at file offset 0x2c is set to ``nsegments`` plus the
        number of headers written here.

        :param nsegments: number of program headers already present in the table.
        """
        # the patched tag must already be in place at offset 0x34
        tag_end = 0x34 + len(self.patched_tag)
        assert self.ncontent[0x34:tag_end] == self.patched_tag

        phdr_format = "<IIIIIIII"
        extra_headers = 0

        if self.data_fallback:
            l.debug("added_data_file_start: %#x", self.added_data_file_start)
            # Original author's note: a zero-sized segment makes the kernel
            # allocate nothing, so an empty data segment is not a concern.
            data_vaddr = self.name_map["ADDED_DATA_START"]
            data_len = len(self.added_data)
            packed_data_header = struct.pack(phdr_format, 1, self.added_data_file_start,
                                             data_vaddr, data_vaddr,
                                             data_len, data_len, 0x6, 0x1000)  # RW
            self.ncontent = utils.str_overwrite(self.ncontent, packed_data_header,
                                                self.original_header_end + 32)
            extra_headers += 1
        # without the fallback no data header is written here
        # (original author's note: the header has already been patched before)

        # the code segment is mapped at the same offset within its page as in the file
        code_vaddr = self.added_code_segment + (self.added_code_file_start % 0x1000)
        code_len = len(self.added_code)
        packed_code_header = struct.pack(phdr_format, 1, self.added_code_file_start,
                                         code_vaddr, code_vaddr,
                                         code_len, code_len, 0x5, 0x1000)  # RX
        self.ncontent = utils.str_overwrite(self.ncontent, packed_code_header,
                                            self.original_header_end)
        extra_headers += 1

        # update the total number of program headers
        packed_count = struct.pack("<H", nsegments + extra_headers)
        self.ncontent = utils.str_overwrite(self.ncontent, packed_count, 0x2c)
# Example #2
# 0
    def _add_segments(self, filename, patches):
        """Rewrite *filename* in place so its program header table also lists the patches' segments.

        The existing program headers are read from the file, the file is
        padded, and a new table (original headers followed by each patch's
        ``new_segment``) is appended at the end.  The header's program-header
        offset (0x1C) and count (0x2c) fields are updated accordingly and the
        location of the first original program header (0x34) is overwritten
        with a marker string.

        :param filename: path of the executable to modify.
        :param patches:  iterable of patches, each exposing a ``new_segment``
                         8-tuple in program-header field order.
        :raises AssertionError: if the file has no program headers, an
                         unexpected program-header entry size, or an unknown
                         segment type.
        """
        # The file is a binary executable: read it in binary mode (matching the
        # "wb" used below) and let the context manager close the handle even
        # if reading fails.
        with open(filename, "rb") as fp:
            content = fp.read()

        # parse the fixed-size file header (16 identification bytes are skipped)
        header_size = 16 + 2 * 2 + 4 * 5 + 2 * 6
        buf = content[0:header_size]
        (cgcef_type, cgcef_machine, cgcef_version, cgcef_entry, cgcef_phoff,
         cgcef_shoff, cgcef_flags, cgcef_ehsize, cgcef_phentsize, cgcef_phnum,
         cgcef_shentsize, cgcef_shnum,
         cgcef_shstrndx) = struct.unpack("<xxxxxxxxxxxxxxxxHHLLLLLHHHHHH", buf)
        phent_size = 8 * 4
        assert cgcef_phnum != 0
        assert cgcef_phentsize == phent_size
        pt_types = {
            0: "NULL",
            1: "LOAD",
            6: "PHDR",
            0x60000000 + 0x474e551: "GNU_STACK",
            0x6ccccccc: "CGCPOV2"
        }

        # dump the original segments
        old_segments = []
        for i in xrange(0, cgcef_phnum):
            entry_start = cgcef_phoff + phent_size * i
            hdr = content[entry_start:entry_start + phent_size]
            (p_type, p_offset, p_vaddr, p_paddr, p_filesz, p_memsz, p_flags,
             p_align) = struct.unpack("<IIIIIIII", hdr)
            assert p_type in pt_types
            old_segments.append((p_type, p_offset, p_vaddr, p_paddr, p_filesz,
                                 p_memsz, p_flags, p_align))

        # align size of the entire ELF
        content = utils.pad_str(content, 0x10)
        # change pointer to program headers to point at the end of the elf
        content = utils.str_overwrite(content, struct.pack("<I", len(content)),
                                      0x1C)

        new_segments = [p.new_segment for p in patches]
        all_segments = old_segments + new_segments

        # add all segments at the end of the file
        for segment in all_segments:
            content = utils.str_overwrite(content,
                                          struct.pack("<IIIIIIII", *segment))

        # we overwrite the first original program header,
        # we do not need it anymore since we have moved original program headers at the bottom of the file
        content = utils.str_overwrite(content, "SHELLPHISH\x00", 0x34)

        # set the total number of segment headers
        content = utils.str_overwrite(content,
                                      struct.pack("<H", len(all_segments)),
                                      0x2c)

        # update the file
        with open(filename, "wb") as fp:
            fp.write(content)
    def remove_pdf(self, pdf_start, pdf_length, check_instruction_addr, check_instruction_size):
        """Cut the whole pages lying inside a file region out of the binary.

        ``self.ncontent`` is rebuilt from ``self.ocontent`` without the cut
        pages, the section-header fields of the file header are zeroed, the
        instruction at *check_instruction_addr* is overwritten with 0x90 bytes,
        and the last entry of ``self.modded_segments`` is replaced by two
        entries describing the parts before and after the cut.

        :param pdf_start:              file offset where the region starts.
        :param pdf_length:             length of the region in bytes.
        :param check_instruction_addr: memory address of the instruction to
                                       overwrite (translated through
                                       ``self.maddress_to_baddress``).
        :param check_instruction_size: number of 0x90 bytes to write there.
        """
        pdf_end = pdf_start + pdf_length
        l.info("pdf is between %08x and %08x" % (pdf_start,  pdf_end))

        # only the 0x1000-aligned pages strictly inside the region are removed
        page_mask = 0xfffff000
        original_last = self.modded_segments[-1]
        cut_end = pdf_end & page_mask
        cut_start = (pdf_start & page_mask) + 0x1000
        cut_size = cut_end - cut_start
        # file offset -> memory address, using the last segment's offset/vaddr pair
        cut_start_mem = cut_start - original_last[1] + original_last[2]
        l.info("cutting the pdf from: %08x to %08x" % (cut_start,cut_end))
        self.ncontent = self.ocontent[:cut_start] + self.ocontent[cut_end:]

        # Zero the section-header fields (offset 0x20, count 0x30, string-table
        # index 0x32) so that the file still loads fine in gdb, ida, ...
        for field_format, field_offset in (("<I", 0x20), ("<H", 0x30), ("<H", 0x32)):
            self.ncontent = utils.str_overwrite(self.ncontent, struct.pack(field_format, 0), field_offset)

        # replace the checking instruction with 0x90 bytes
        nop_bytes = "\x90" * check_instruction_size
        self.ncontent = utils.str_overwrite(self.ncontent, nop_bytes,
                                            self.maddress_to_baddress(check_instruction_addr))
        self.max_convertible_address = cut_start_mem

        # sanity-check the program header fields of the original file header
        header_size = 16 + 2*2 + 4*5 + 2*6
        header_fields = struct.unpack("<xxxxxxxxxxxxxxxxHHLLLLLHHHHHH", self.ocontent[0:header_size])
        cgcef_phentsize = header_fields[8]
        cgcef_phnum = header_fields[9]
        phent_size = 8 * 4
        assert cgcef_phnum != 0
        assert cgcef_phentsize == phent_size

        # split the last segment into the part before the cut and the part after it
        segments = self.modded_segments
        last_segment = segments[-1]
        p_type, p_offset, p_vaddr, p_paddr, p_filesz, p_memsz, p_flags, p_align = last_segment
        head_size = cut_start_mem - p_vaddr
        tail_vaddr = p_vaddr + head_size + cut_size
        pre_cut_segment = (p_type, p_offset, p_vaddr, p_paddr, head_size, head_size, p_flags, p_align)
        post_cut_segment = (p_type, p_offset + head_size, tail_vaddr, tail_vaddr,
                            p_filesz - cut_size - head_size, p_memsz - cut_size - head_size,
                            p_flags, p_align)

        change_report = (map(hex,last_segment), map(hex,pre_cut_segment), map(hex,post_cut_segment))
        l.info("last segment changed from \n%s to \n%s\n%s" % change_report)
        self.modded_segments = segments[:-1] + [pre_cut_segment,post_cut_segment]
    def patch_bin(self, address, new_content):
        """Overwrite the bytes mapped at memory *address* with *new_content*.

        ``self.get_memory_translation_list`` supplies the (start, end) file
        ranges for the memory range; since the content could be split across
        several of them, each range receives the next slice of *new_content*.
        """
        consumed = 0
        for file_start, file_end in self.get_memory_translation_list(address, len(new_content)):
            span = file_end - file_start
            piece = new_content[consumed:consumed + span]
            self.ncontent = utils.str_overwrite(self.ncontent, piece, file_start)
            # advance by what was actually taken (the last slice may be shorter)
            consumed += len(piece)
    def setup_headers(self,segments):
        """Move the program header table to the end of ``self.ncontent``.

        Does nothing when ``self.is_patched()`` is true.  Otherwise the
        content is padded, the program-header offset field (0x1C) is pointed
        at the new end of the content, every entry of *segments* is packed and
        appended there, ``self.original_header_end`` records where the copied
        table ends, the patched tag is written at 0x34, and
        ``self.additional_headers_size`` zero bytes are reserved after the
        table for headers added later.

        :param segments: iterable of 8-field program header tuples.
        """
        if self.is_patched():
            return

        # align size of the entire ELF, then make the program header pointer
        # refer to the (new) end of the file
        padded = utils.pad_str(self.ncontent, 0x10)
        table_offset = struct.pack("<I", len(padded))
        self.ncontent = utils.str_overwrite(padded, table_offset, 0x1C)

        # copy the original program headers (potentially modified by patches
        # and/or pdf removal) to their new place at the end of the file
        for header_fields in segments:
            packed_header = struct.pack("<IIIIIIII", *header_fields)
            self.ncontent = utils.str_overwrite(self.ncontent, packed_header)
        self.original_header_end = len(self.ncontent)

        # the first original program header is no longer needed now that the
        # table lives at the bottom of the file: overwrite it with the tag
        self.ncontent = utils.str_overwrite(self.ncontent, self.patched_tag, 0x34)

        # reserve zeroed room for the additional headers; the space is always
        # added, even if the data header is only used by the fallback solution
        self.ncontent = self.ncontent + "\x00" * self.additional_headers_size
    def apply_patches(self, patches):
        """Apply *patches* to ``self.ncontent``, appending added code/data and fixing headers.

        Processing order:

        * stackable ``InsertCodePatch`` es are merged into the highest-priority
          patch at the same address, and ``AddLabelPatch`` names are recorded;
        * assembly labels of all code-carrying patches are checked for
          duplicates;
        * 0) raw file / raw memory patches;
        * 1) data patches (either appended to the file when
          ``self.data_fallback`` is set, or laid out after ``ADDED_DATA_START``
          with an entry-point stub that copies the initialized data);
        * 2) ``AddCodePatch``, 3) ``AddEntryPointPatch``, 4) ``InlinePatch``,
          5) ``InsertCodePatch`` (with rollback when a detour fails);
        * 6) segment headers are rewritten if anything requiring it was applied.

        Note that *patches* is modified in place: merged stackable patches are
        removed from it and an ``INIT_DATA`` entry-point patch may be appended.

        :param patches: list of patch objects.
        :raises DuplicateLabelsException: if any assembly label is defined
                        more than once across the code-carrying patches.
        :raises IncompatiblePatchesException: if more than one
                        ``SegmentHeaderPatch`` is given.
        """
        # deal with stackable patches
        # add stackable patches to the one with highest priority
        insert_code_patches = [p for p in patches if isinstance(p, InsertCodePatch)]
        insert_code_patches_dict = defaultdict(list)
        for p in insert_code_patches:
            insert_code_patches_dict[p.addr].append(p)
        insert_code_patches_dict_sorted = defaultdict(list)
        for k,v in insert_code_patches_dict.iteritems():
            insert_code_patches_dict_sorted[k] = sorted(v,key=lambda x:-1*x.priority)

        insert_code_patches_stackable = [p for p in patches if isinstance(p, InsertCodePatch) and p.stackable]
        for sp in insert_code_patches_stackable:
            assert len(sp.dependencies) == 0
            if sp.addr in insert_code_patches_dict_sorted:
                highest_priority_at_addr = insert_code_patches_dict_sorted[sp.addr][0]
                if highest_priority_at_addr != sp:
                    highest_priority_at_addr.asm_code += "\n"+sp.asm_code+"\n"
                    patches.remove(sp)

        #deal with AddLabel patches
        lpatches = [p for p in patches if (isinstance(p, AddLabelPatch))]
        for p in lpatches:
            self.name_map[p.name] = p.addr

        # check for duplicate labels, it is not very necessary for this backend
        # but it is better to behave in the same way of the reassembler backend
        relevant_patches = [p for p in patches if (isinstance(p, AddCodePatch) or \
                isinstance(p, InsertCodePatch) or isinstance(p, AddEntryPointPatch))]
        all_code = ""
        for p in relevant_patches:
            # NOTE(review): InsertCodePatch is read through ``.code`` here but
            # through ``.asm_code`` in the stackable handling above -- confirm
            # against the patch class that both attributes exist.
            if isinstance(p, InsertCodePatch):
                code = p.code
            else:
                code = p.asm_code
            all_code += "\n"+code+"\n"
        labels = utils.string_to_labels(all_code)
        duplicates = set([x for x in labels if labels.count(x) > 1])
        # a single duplicated label is already an error (the previous
        # ``len(duplicates) > 1`` test let exactly one duplicate through)
        if duplicates:
            raise DuplicateLabelsException("found duplicate assembly labels: %s" % (str(duplicates)))

        # for now any added code will be executed by jumping out and back ie CGRex
        # apply all add code patches
        self.added_code_file_start = len(self.ncontent)
        self.name_map.force_insert("ADDED_CODE_START",(len(self.ncontent) % 0x1000) + self.added_code_segment)

        # 0) RawPatch:
        for patch in patches:
            if isinstance(patch, RawFilePatch):
                self.ncontent = utils.str_overwrite(self.ncontent,patch.data,patch.file_addr)
                self.added_patches.append(patch)
                l.info("Added patch: " + str(patch))
        for patch in patches:
            if isinstance(patch, RawMemPatch):
                self.patch_bin(patch.addr,patch.data)
                self.added_patches.append(patch)
                l.info("Added patch: " + str(patch))

        if self.data_fallback:
            # 1) all data patches are appended to the file, before the added code
            self.added_data_file_start = len(self.ncontent)
            curr_data_position = self.name_map["ADDED_DATA_START"]
            for patch in patches:
                if isinstance(patch, AddRWDataPatch) or isinstance(patch, AddRODataPatch) or \
                        isinstance(patch, AddRWInitDataPatch):
                    # patches without explicit data contribute zero bytes
                    if hasattr(patch,"data"):
                        final_patch_data = patch.data
                    else:
                        final_patch_data = "\x00"*patch.len
                    self.added_data += final_patch_data
                    if patch.name is not None:
                        self.name_map[patch.name] = curr_data_position
                    curr_data_position += len(final_patch_data)
                    self.ncontent = utils.str_overwrite(self.ncontent, final_patch_data)
                    self.added_patches.append(patch)
                    l.info("Added patch: " + str(patch))
            self.ncontent = utils.pad_str(self.ncontent, 0x10)  # some minimal alignment may be good

            # the added code now starts after the data that was just appended
            self.added_code_file_start = len(self.ncontent)
            self.name_map.force_insert("ADDED_CODE_START",(len(self.ncontent) % 0x1000) + self.added_code_segment)
        else:
            # 1.1) AddRWDataPatch
            for patch in patches:
                if isinstance(patch, AddRWDataPatch):
                    if patch.name is not None:
                        self.name_map[patch.name] = self.name_map["ADDED_DATA_START"] + self.added_rwdata_len
                    self.added_rwdata_len += patch.len
                    self.added_patches.append(patch)
                    l.info("Added patch: " + str(patch))

            # 1.2) AddRWInitDataPatch
            for patch in patches:
                if isinstance(patch, AddRWInitDataPatch):
                    self.to_init_data += patch.data
                    if patch.name is not None:
                        self.name_map[patch.name] = self.name_map["ADDED_DATA_START"] + self.added_rwdata_len + \
                                self.added_rwinitdata_len
                    self.added_rwinitdata_len += len(patch.data)
                    self.added_patches.append(patch)
                    l.info("Added patch: " + str(patch))
            if self.to_init_data != "":
                # entry-point stub embedding the initialized bytes and copying
                # them to their place right after the uninitialized RW data
                code = '''
                jmp _skip_data
                _to_init_data:
                    db %s
                _skip_data:
                    mov esi, _to_init_data
                    mov edi, %s
                    mov ecx, %d
                    cld 
                    rep movsb
                ''' % (",".join([hex(ord(x)) for x in self.to_init_data]), \
                        hex(self.name_map["ADDED_DATA_START"] + self.added_rwdata_len), \
                        self.added_rwinitdata_len)
                patches.append(AddEntryPointPatch(code,priority=1000,name="INIT_DATA"))

            # 1.3) AddRODataPatch
            for patch in patches:
                if isinstance(patch, AddRODataPatch):
                    # NOTE(review): this also grows to_init_data after the copy
                    # stub above was generated -- looks unintended, confirm.
                    self.to_init_data += patch.data
                    if patch.name is not None:
                        self.name_map[patch.name] = self.get_current_code_position()
                    self.added_code += patch.data
                    self.ncontent = utils.str_overwrite(self.ncontent, patch.data)
                    self.added_patches.append(patch)
                    l.info("Added patch: " + str(patch))

        # 2) AddCodePatch
        # resolving symbols
        current_symbol_pos = self.get_current_code_position()
        for patch in patches:
            if isinstance(patch, AddCodePatch):
                if patch.is_c:
                    code_len = len(utils.compile_c(patch.asm_code,optimization=patch.optimization))
                else:
                    code_len = len(utils.compile_asm_fake_symbol(patch.asm_code, current_symbol_pos))
                if patch.name is not None:
                    self.name_map[patch.name] = current_symbol_pos
                current_symbol_pos += code_len
        # now compile for real
        for patch in patches:
            if isinstance(patch, AddCodePatch):
                if patch.is_c:
                    new_code = utils.compile_c(patch.asm_code,optimization=patch.optimization)
                else:
                    new_code = utils.compile_asm(patch.asm_code, self.get_current_code_position(), self.name_map)
                self.added_code += new_code
                self.ncontent = utils.str_overwrite(self.ncontent, new_code)
                self.added_patches.append(patch)
                l.info("Added patch: " + str(patch))

        # 3) AddEntryPointPatch
        # basically like AddCodePatch but we detour by changing oep
        # and we jump at the end of all of them
        # resolving symbols
        if any([isinstance(p, AddEntryPointPatch) for p in patches]):
            pre_entrypoint_code_position = self.get_current_code_position()
            current_symbol_pos = self.get_current_code_position()
            entrypoint_patches = [p for p in patches if isinstance(p,AddEntryPointPatch)]
            between_restore_entrypoint_patches = sorted([p for p in entrypoint_patches if not p.after_restore], \
                key=lambda x:-1*x.priority)
            after_restore_entrypoint_patches = sorted([p for p in entrypoint_patches if p.after_restore], \
                key=lambda x:-1*x.priority)

            current_symbol_pos += len(utils.compile_asm_fake_symbol("pusha\n", current_symbol_pos))
            for patch in between_restore_entrypoint_patches:
                code_len = len(utils.compile_asm_fake_symbol(patch.asm_code, current_symbol_pos))
                if patch.name is not None:
                    self.name_map[patch.name] = current_symbol_pos
                current_symbol_pos += code_len
            # now compile for real
            new_code = utils.compile_asm(ASM_ENTRY_POINT_PUSH_ENV, self.get_current_code_position())
            self.added_code += new_code
            self.ncontent = utils.str_overwrite(self.ncontent, new_code)
            for patch in between_restore_entrypoint_patches:
                new_code = utils.compile_asm(patch.asm_code, self.get_current_code_position(), self.name_map)
                self.added_code += new_code
                self.added_patches.append(patch)
                l.info("Added patch: " + str(patch))
                self.ncontent = utils.str_overwrite(self.ncontent, new_code)

            restore_code = ASM_ENTRY_POINT_RESTORE_ENV
            current_symbol_pos += len(utils.compile_asm_fake_symbol(restore_code, current_symbol_pos))
            for patch in after_restore_entrypoint_patches:
                code_len = len(utils.compile_asm_fake_symbol(patch.asm_code, current_symbol_pos))
                if patch.name is not None:
                    self.name_map[patch.name] = current_symbol_pos
                current_symbol_pos += code_len
            # now compile for real
            new_code = utils.compile_asm(restore_code, self.get_current_code_position())
            self.added_code += new_code
            self.ncontent = utils.str_overwrite(self.ncontent, new_code)
            for patch in after_restore_entrypoint_patches:
                new_code = utils.compile_asm(patch.asm_code, self.get_current_code_position(), self.name_map)
                self.added_code += new_code
                self.ncontent = utils.str_overwrite(self.ncontent, new_code)
                self.added_patches.append(patch)
                l.info("Added patch: " + str(patch))

            # redirect the entry point to the added code, which ends with a
            # jump back to the original entry point
            oep = self.get_oep()
            self.set_oep(pre_entrypoint_code_position)
            new_code = utils.compile_jmp(self.get_current_code_position(),oep)
            self.added_code += new_code
            self.ncontent += new_code

        # 4) InlinePatch
        # we assume the patch never patches the added code
        for patch in patches:
            if isinstance(patch, InlinePatch):
                new_code = utils.compile_asm(patch.new_asm, patch.instruction_addr, self.name_map)
                # the replacement must be exactly as long as the replaced instruction
                assert len(new_code) == self.project.factory.block(patch.instruction_addr, num_inst=1).size
                file_offset = self.project.loader.main_object.addr_to_offset(patch.instruction_addr)
                self.ncontent = utils.str_overwrite(self.ncontent, new_code, file_offset)
                self.added_patches.append(patch)
                l.info("Added patch: " + str(patch))

        # 5) InsertCodePatch
        # these patches specify an address in some basic block, In general we will move the basic block
        # and fix relative offsets
        # With this backend here we can fail applying a patch, in case, resolve dependencies
        insert_code_patches = [p for p in patches if isinstance(p, InsertCodePatch)]
        insert_code_patches = sorted([p for p in insert_code_patches],key=lambda x:-1*x.priority)
        applied_patches = []
        while True:
            name_list = [str(p) if (p is None or p.name is None) else p.name for p in applied_patches]
            l.info("applied_patches is: |" + "-".join(name_list)+"|")
            assert all([a == b for a,b in zip(applied_patches,insert_code_patches)])
            for patch in insert_code_patches[len(applied_patches):]:
                self.save_state(applied_patches)
                try:
                    l.info("Trying to add patch: " + str(patch))
                    new_code = self.insert_detour(patch)
                    self.added_code += new_code
                    self.ncontent = utils.str_overwrite(self.ncontent, new_code)
                    applied_patches.append(patch)
                    self.added_patches.append(patch)
                    l.info("Added patch: " + str(patch))
                except (DetourException, MissingBlockException, DoubleDetourException) as e:
                    l.warning(e)
                    insert_code_patches, removed = self.handle_remove_patch(insert_code_patches,patch)
                    #print map(str,removed)
                    applied_patches = self.restore_state(applied_patches, removed)
                    l.warning("One patch failed, rolling back InsertCodePatch patches. Failed patch: "+str(patch))
                    # leave the for loop without reaching its else: the while
                    # loop then retries with the reduced patch list
                    break
                    # TODO: right now rollback goes back to 0 patches, we may want to go back less
                    # the solution is to save touched_bytes and ncontent indexed by applied patch
                    # and go back to the biggest compatible list of patches
            else:
                break #at this point we applied everything in current insert_code_patches
                # TODO symbol name, for now no name_map for InsertCode patches

        header_patches = [InsertCodePatch,InlinePatch,AddEntryPointPatch,AddCodePatch, \
                AddRWDataPatch,AddRODataPatch,AddRWInitDataPatch]
        if any([isinstance(p,ins) for ins in header_patches for p in self.added_patches]) or \
                any([isinstance(p,SegmentHeaderPatch) for p in patches]) or self.pdf_removed:
            # either implicitly (because of a patch adding code or data) or explicitly, we need to change segment headers

            # 6) SegmentHeaderPatch
            segment_header_patches = [p for p in patches if isinstance(p,SegmentHeaderPatch)]
            if len(segment_header_patches) > 1:
                msg = "more than one patch tries to change segment headers: " + "|".join([str(p) for p in segment_header_patches])
                raise IncompatiblePatchesException(msg)
            elif len(segment_header_patches) == 1:
                segment_patch = segment_header_patches[0]
                segments = segment_patch.segment_headers
                l.info("Added patch: " + str(segment_patch))
            else:
                segments = self.modded_segments

            for patch in [p for p in patches if isinstance(p,AddSegmentHeaderPatch)]:
                # add after the first
                segments = [segments[0]] + [patch.new_segment] + segments[1:]

            if not self.data_fallback:
                # without the fallback the added RW data lives past the end of
                # the last segment: grow its memory size to cover it
                last_segment = segments[-1]
                p_type, p_offset, p_vaddr, p_paddr, p_filesz, p_memsz, p_flags, p_align = last_segment
                last_segment =  p_type, p_offset, p_vaddr, p_paddr, \
                       p_filesz, p_memsz + self.added_rwdata_len + self.added_rwinitdata_len, p_flags, p_align
                segments[-1] = last_segment
            self.setup_headers(segments)
            self.set_added_segment_headers(len(segments))
            l.debug("final symbol table: "+ repr([(k,hex(v)) for k,v in self.name_map.iteritems()]))
        else:
            l.info("no patches, the binary will not be touched")
 def set_oep(self, new_oep):
     """Store *new_oep* as the binary's entry point.

     The value is packed as a little-endian 32-bit integer and written into
     ``self.ncontent`` at file offset 0x18.
     """
     packed_entry = struct.pack("<I", new_oep)
     self.ncontent = utils.str_overwrite(self.ncontent, packed_entry, 0x18)