Code example #1
File: binaryBuilder.py Project: ninishu/CCR
    def patchInitArray(self, secName, sectionChunk, fixups):
        """ Patches '.init_array' section """
        if fixups:
            initBar = util.ProgressBar(len(fixups))
            fixups = sorted(fixups, key=lambda F: F.offset)

        pos = fixups[0].offset if fixups else 0
        self.instBin += sectionChunk[:pos]
        patchCtr = 0

        # TODO: Check if init_array has relevant values in PIC/PIE
        if fixups:
            for i, FI in enumerate(fixups):
                if self.EI.isInReorderRange(FI.refTo):
                    newRefTo = self.EI.getBBlByVA(FI.refTo).newVA
                    self.instBin += struct.pack(self.getFormat(FI.derefSz), newRefTo)
                    patchCtr += 1
                    logging.debug("   [%s] Original: 0x%x -> Updated: 0x%x" %
                                 (secName, FI.refTo, newRefTo))
                else:
                    self.instBin += struct.pack(self.getFormat(FI.derefSz), FI.refTo)
                pos += FI.derefSz
                initBar += 1
            initBar.finish()

        self.instBin += sectionChunk[pos:]
        self.memo.numInitArrayPatch += patchCtr
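
The `util.ProgressBar` used in the CCR excerpts is project-local and its implementation is not shown in this listing. Purely as an illustration, a counter-style bar supporting the `ProgressBar(total)`, `bar += 1`, `bar.finish()` protocol seen above could be sketched as follows (the class name and internals are assumptions, not CCR's actual code):

import sys

class CounterProgressBar(object):
    """Illustrative stand-in: construct with a total, advance with '+=', close with finish()."""

    def __init__(self, total, width=50):
        self.total = max(int(total), 1)
        self.width = width
        self.count = 0

    def __iadd__(self, n):
        self.count = min(self.count + n, self.total)
        filled = self.width * self.count // self.total
        sys.stdout.write('\r[%s%s] %d/%d' %
                         ('#' * filled, '.' * (self.width - filled), self.count, self.total))
        sys.stdout.flush()
        return self

    def finish(self):
        sys.stdout.write('\n')
        sys.stdout.flush()
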
Code example #2
def write_clean_data_file(clean_data_file_name):
    """ Write clean data to the specified data file. """
    global data_to_write_queue
    global number_of_data_lines
    # Create a progress bar to show progress of processing.
    progress_bar = util.ProgressBar(number_of_data_lines.value,
                                    'Cleaning headers', 71)
    counter = 0
    with open(clean_data_file_name, FILE_WRITE_MODE) as data_file:
        while True:
            # Keep retrieving header data from write queue until timeout exception
            # is raised (should mean that no more data will be added to the queue).
            try:
                clean_data = data_to_write_queue.get(block=True,
                                                     timeout=QUEUE_TIMEOUT)
                data_file.write('{0}\n'.format(clean_data))
                counter += 1
                # Update progress bar.
                progress_bar.update(counter, total=number_of_data_lines.value)
            except queue.Empty:
                # Finish up the writing process.
                progress_bar.clean()
                break
    # Return the number of entries that were written.
    return counter
Code example #3
def write_to_data_file(data_file_name, num_files):
    """ Retrieves data maps from write queue and appends it to specified data file. """
    global data_to_write_queue
    progress_bar = util.ProgressBar(num_files, 'Extracting headers', 73)
    counter = 0
    with open(data_file_name, FILE_WRITE_MODE) as data_file:
        while True:
            # Keep retrieving header data from write queue until timeout exception 
            # is raised (should mean that no more data will be added to the queue).
            try:
                data = data_to_write_queue.get(block=True, timeout=QUEUE_TIMEOUT)
                stringified_data = util.stringify_headers(data)
                data_file.write('{0}\n'.format(stringified_data))
                counter += 1
                # Update progress bar.
                progress_bar.update(counter)
            except queue.Empty:
                # Finish up the writing process.
                progress_bar.clean()
                util.log_print('{0} entries written'.format(counter))
                break
            except Exception as error:
                print(error)
                util.log_print('{0} entries written'.format(counter))
                break
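
Both queue-draining writers above assume that other processes feed `data_to_write_queue` and that a `queue.Empty` timeout signals the end of input; the producer side is not part of these excerpts. A minimal sketch of how such a pipeline might be wired (the worker function, sample data, and constants are hypothetical):

import multiprocessing

QUEUE_TIMEOUT = 5          # seconds the writer waits before assuming producers are done
FILE_WRITE_MODE = 'w'
data_to_write_queue = multiprocessing.Queue()

def clean_headers_worker(raw_lines, out_queue):
    """Hypothetical producer: clean raw header lines and push them onto the shared queue."""
    for line in raw_lines:
        out_queue.put(line.strip().lower())

if __name__ == '__main__':
    producer = multiprocessing.Process(target=clean_headers_worker,
                                       args=(['Host: a\n', 'User-Agent: b\n'],
                                             data_to_write_queue))
    producer.start()
    producer.join()
    # The writer functions above can now drain the queue; once it stays empty for
    # QUEUE_TIMEOUT seconds, queue.Empty is raised and the writer finishes up.
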
Code example #4
File: binaryBuilder.py Project: ninishu/CCR
    def patchSymbolTable(self, secName, sectionChunk):
        """ Patches all symbol values after randomization for '.dynsym' and '.symtab' section """
        symbolTable = self.EP.elf.get_section_by_name(secName)
        assert(len(sectionChunk) % symbolTable.num_symbols() == 0)
        symbolBar = util.ProgressBar(symbolTable.num_symbols())

        patchCtr = 0
        symOffset = 0
        randRangeBottom = self.EI._getTextSecVA() + self.EI.reorderObjStartFromText
        randRangeTop = randRangeBottom + self.EI.reorderedObjSize

        # [FIXME] Define a temporary bag to contain symval and symsize
        #         Ugly hack - it arises from a discrepancy in the function (symbol) size,
        #                     which lacks the alignment size; maybe we could do better
        sym_temp_lookup = {}
        for symbol in symbolTable.iter_symbols():
            symVal, symSz = symbol['st_value'], symbol['st_size']
            if randRangeBottom <= symVal < randRangeTop:
                sym_temp_lookup[symVal] = symSz

        for symbol in symbolTable.iter_symbols():
            """
            # The first 8 bytes have to be always identical - combined to symProperty
            type, bind = symbol['st_info']['type'], symbol['st_info']['bind']
            sym_other, sym_shndx = symbol['st_other']['visibility'], symbol['st_shndx']
            """
            symProperty = sectionChunk[symOffset:symOffset + 8]
            symVal, symSz = symbol['st_value'], symbol['st_size']
            self.instBin += symProperty

            # We only need to update the symbol value/size (either absolute or relative VA) here
            if randRangeBottom <= symVal < randRangeTop:
                try:
                    sym_fn = self.EI.getBBlByVA(symVal)
                    newSymVal = sym_fn.newVA
                    self.instBin += self.PK(FMT.LONG, newSymVal)
                    self.instBin += self.PK(FMT.LONG, sym_temp_lookup[symVal])

                    patchCtr += 1
                    logging.debug(" [%s] Original: 0x%x -> Updated: 0x%x" % (secName, symVal, newSymVal))

                    # [NEW] Let's save the symbol name for each function defined as a symbol
                    sym_fn.parent.name = symbol.name

                except AttributeError:
                    self.instBin += self.PK(FMT.LONG, symVal)
                    logging.warning(" [%s] Failed to update the symbol: 0x%x " % (secName, symVal))
            else:
                self.instBin += self.PK(FMT.LONG, symVal)
                self.instBin += self.PK(FMT.LONG, symSz)

            # Each symbol entry is 24B in size; move on to the next entry
            symOffset += 24
            symbolBar += 1

        symbolBar.finish()
        self.memo.numSymPatch += patchCtr
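
patchSymbolTable relies on the fixed Elf64_Sym layout: the first 8 bytes (st_name, st_info, st_other, st_shndx) are copied verbatim, the following st_value and st_size quadwords are rewritten, and each entry is 24 bytes, which is why symOffset advances by 24. A small standalone illustration of that layout (not part of CCR):

import struct

# Elf64_Sym: st_name (4B), st_info (1B), st_other (1B), st_shndx (2B), st_value (8B), st_size (8B)
ELF64_SYM = struct.Struct('<IBBHQQ')
assert ELF64_SYM.size == 24  # matches the 24-byte stride used above

def split_symbol_entry(entry):
    """Return the 8-byte property prefix plus (st_value, st_size) of one raw entry."""
    _, _, _, _, st_value, st_size = ELF64_SYM.unpack(entry)
    return entry[:8], st_value, st_size
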
Code example #5
File: reorderInfo.py Project: quwenjie/x86-sok
    def __init__(self, RI):
        """
        Construct the essential information based on the information collected from the compiler toolchain
            a) Build the layout tree - basic blocks, functions, and objects
            b) Build the entire fixup info (.text, .rodata, .data.rel.ro, .data and .init_array section)
            c) Confirm that the reconstructed data is sane before performing randomization
        :param RI:
        """

        # Pre-processing: data collection and preparation for building essential information
        binInfo = RI['bin_info']
        # objInfo = (RI['obj_size'], RI['obj_func_cnt'], RI['obj_src_type'], RI['obj_offset'], RI['obj_section'])
        funcInfo = (RI['func_size'], RI['func_bb_cnt'], RI['func_offset'],
                    RI['func_section'], RI['func_type'])
        bbInfo = (RI['bb_size'], RI['bb_fixup_cnt'], RI['bb_fall_through'], RI['bb_offset'],  \
                RI['bb_section'], RI['bb_padding'], RI['bb_assemble'])

        fixupsText = (RI[C.DS_FIXUP_TEXT[0]], RI[C.DS_FIXUP_TEXT[1]],
                      RI[C.DS_FIXUP_TEXT[2]], RI[C.DS_FIXUP_TEXT[3]],
                      RI[C.DS_FIXUP_TEXT[4]], RI[C.DS_FIXUP_TEXT[5]],
                      RI[C.DS_FIXUP_TEXT[6]])
        fixupsRodata = (RI[C.DS_FIXUP_RODATA[0]], RI[C.DS_FIXUP_RODATA[1]],
                        RI[C.DS_FIXUP_RODATA[2]], RI[C.DS_FIXUP_RODATA[3]],
                        RI[C.DS_FIXUP_RODATA[4]])
        fixupsData = (RI[C.DS_FIXUP_DATA[0]], RI[C.DS_FIXUP_DATA[1]],
                      RI[C.DS_FIXUP_DATA[2]], RI[C.DS_FIXUP_DATA[3]],
                      RI[C.DS_FIXUP_DATA[4]])
        fixupsDataRel = (RI[C.DS_FIXUP_DATAREL[0]], RI[C.DS_FIXUP_DATAREL[1]],
                         RI[C.DS_FIXUP_DATAREL[2]], RI[C.DS_FIXUP_DATAREL[3]],
                         RI[C.DS_FIXUP_DATAREL[4]])
        fixupsInitArray = (RI[C.DS_FIXUP_INIT_ARR[0]],
                           RI[C.DS_FIXUP_INIT_ARR[1]],
                           RI[C.DS_FIXUP_INIT_ARR[2]],
                           RI[C.DS_FIXUP_INIT_ARR[3]],
                           RI[C.DS_FIXUP_INIT_ARR[4]])

        layoutInfoCnt = len(RI['func_size']) + len(RI['bb_size'])
        self.fixupInfoCnt = len(RI[C.DS_FIXUP_TEXT[0]]) + len(RI[C.DS_FIXUP_RODATA[0]]) + \
                            len(RI[C.DS_FIXUP_DATA[0]]) + len(RI[C.DS_FIXUP_DATAREL[0]]) + \
                            len(RI[C.DS_FIXUP_INIT_ARR[0]])
        bar = util.ProgressBar(layoutInfoCnt + self.fixupInfoCnt)

        # a) Construct BasicBlocks, Functions, Objects and the binary in Bottom-Up way
        self.constructInfo = BasicBlocks(bar, binInfo, funcInfo, bbInfo)
        # self.constructInfo.storeAlignSize(RI['align_size'])

        # b) Construct all fixups in .text, .rodata, .data, data.rel.ro, and .init_array
        self.FixupsInText, self.FixupsInRodata = None, None
        self.FixupsInData, self.FixupsInDataRel = None, None
        self.FixupsInInitArray, self.FixupsInEhframe = None, None
        self.processFixups(bar, RI, fixupsText, fixupsRodata, fixupsData,
                           fixupsDataRel, fixupsInitArray)
        bar.finish()
Code example #6
def validator(request):
    problem_dir = RUN_DIR / 'test/problems/identity'
    os.chdir(problem_dir)

    h = hashlib.sha256(bytes(Path().cwd())).hexdigest()[-6:]
    tmpdir = Path(tempfile.gettempdir()) / ('bapctools_' + h)
    tmpdir.mkdir(exist_ok=True)
    p = problem.Problem(Path('.'), tmpdir)
    validator = validate.OutputValidator(p, RUN_DIR / 'support' / request.param)
    print(util.ProgressBar.current_bar)
    bar = util.ProgressBar('build', max_len=1)
    validator.build(bar)
    bar.finalize()
    yield (p, validator)
    os.chdir(RUN_DIR)
Code example #7
    def train(self):
        """
        Train the autoencoder.
        """
        training_data = torchvision.datasets.ImageFolder(
            self.in_path,
            torchvision.transforms.Compose([torchvision.transforms.ToTensor()
                                            ]))

        data_loader = torch.utils.data.DataLoader(training_data,
                                                  batch_size=self.batch_size,
                                                  shuffle=True,
                                                  num_workers=2)

        criterion = torch.nn.L1Loss().cuda(
        ) if self.use_cuda else torch.nn.L1Loss()

        progress_bar = util.ProgressBar()

        if not os.path.exists('saves/'):
            os.makedirs('saves/')

        for epoch in range(self.start_epoch, self.num_epochs + 1):
            print('Epoch {}/{}'.format(epoch, self.num_epochs))
            for i, data in enumerate(data_loader, 1):
                x, _ = data
                x = x.to(self.device)

                output = self.net(x)

                loss = criterion(output, x)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                self.losses.append(loss.item())  # store the scalar, not the graph-holding tensor

                if self.verbose and (i % 10 == 0 or i == len(data_loader) - 1):
                    info_str = 'loss: {:.4f}'.format(self.losses[-1])
                    progress_bar.update(max_value=len(data_loader),
                                        current_value=i + 1,
                                        info=info_str)

            progress_bar.new_line()

            self.save(epoch=epoch,
                      path='saves/' + self.name + '_' + str(epoch) + '.pth')
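
The PyTorch training loops in this listing (here and in examples #12 and #13) construct `util.ProgressBar()` with no arguments and drive it through `update(max_value=..., current_value=..., info=...)` and `new_line()`, a different interface from the counter-style bar in the CCR excerpts. A minimal illustrative stand-in for that interface (names and rendering are assumptions, not the projects' actual code):

import sys

class UpdateStyleProgressBar(object):
    """Illustrative stand-in matching the update()/new_line() calls used above."""

    def __init__(self, bar_length=30):
        self.bar_length = bar_length

    def update(self, max_value, current_value, info=''):
        fraction = min(float(current_value) / max_value, 1.0)
        filled = int(self.bar_length * fraction)
        bar = '=' * filled + '-' * (self.bar_length - filled)
        sys.stdout.write('\r[{}] {:3.0f}% {}'.format(bar, fraction * 100, info))
        sys.stdout.flush()

    def new_line(self):
        sys.stdout.write('\n')
        sys.stdout.flush()
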
Code example #8
File: binaryBuilder.py Project: ninishu/CCR
    def patchDataSection(self, sectionChunk, fixups):
        """
        Patches all collected fixups in .rodata/.data/.data.rel.ro section
        :param sectionChunk:
        :param fixups:
        :return:
        """

        # Step 1) Check if any fixupData exists
        if fixups:
            dataBar = util.ProgressBar(len(fixups))
            fixups = sorted(fixups, key=lambda F: F.offset)

        # Step 2) Initialize the first fixup offset in the section
        pos = fixups[0].offset if fixups else 0

        # Step 3) Copy all data until the first fixup appears if any
        self.instBin += sectionChunk[:pos]

        # Step 4) The fixups to be patched might not be contiguous,
        #         hence take care of the non-fixup data during updates
        if fixups:
            for i, FI in enumerate(fixups):
                # newRefVals are already computed from performReorder()
                refSz, newRefVal = FI.derefSz, FI.newRefVal
                self.instBin += struct.pack(self.getFormat(refSz), newRefVal)
                pos += refSz
                if i < len(fixups) - 1 and fixups[i + 1].offset > pos:
                    nextFixupOffset = fixups[i + 1].offset
                    self.instBin += sectionChunk[pos:nextFixupOffset]
                    pos = nextFixupOffset
                dataBar += 1
            dataBar.finish()

        # Step 5) Keep the last remaining data intact, if any
        self.instBin += sectionChunk[pos:]
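
Several of the CCR patchers above call `self.getFormat(derefSz)` to turn a dereference size into a struct format string before packing the new reference value. The real mapping is not shown in these excerpts; purely as an illustration, it could look like the following (signedness is a guess here, since relative references can be negative):

import struct

# Hypothetical size -> struct format mapping; CCR's actual getFormat() is not shown in this listing.
_DEREF_FORMATS = {1: '<b', 2: '<h', 4: '<i', 8: '<q'}

def get_format(deref_size):
    return _DEREF_FORMATS[deref_size]

assert struct.calcsize(get_format(4)) == 4
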
Code example #9
File: binaryBuilder.py Project: ninishu/CCR
    def patchCodeSection(self, sectionChunk):
        """
        Patches all collected fixups in .text section
        :param sectionChunk:
        :return:
        """
        textBar = util.ProgressBar(len(self.randomizedBBContainer))

        # Fixme: Would be better if the fixup could be updated in updateFixupRefs() of reorderEngine
        def updateRemainingFixups(FI, kind):
            if FI.refBB:
                self.instBin += textSection[pos:FI.offset]
                if not FI.isRela:  # Absolute value
                    FI.newRefVal = FI.newRefTo = FI.refBB.newVA
                else:  # Relative value
                    FI.newRefTo = FI.refBB.newVA
                    FI.newRefVal = FI.newRefTo - FI.VA - FI.derefSz
                self.instBin += struct.pack(self.getFormat(FI.derefSz), FI.newRefVal)
                logging.debug("[%s Fixup] 0x%x->0x%x@0x%x (BBVA: 0x%x->0x%x)" %
                             (kind, FI.derefVal, FI.newRefVal, FI.VA, FI.refBB.VA, FI.refBB.newVA))

            else:
                self.instBin += textSection[pos:FI.offset + FI.derefSz]
                logging.debug("[%s Fixup] Not updated: 0x%x (%s)" % (kind, FI.derefVal, FI.target))

        # Tricky - elf data() might be missing padding bytes, so use the code section chunk instead
        # textSection = self.EP.elf.get_section_by_name('.text').data()
        textSection = sectionChunk
        MAINOFFSZ = SZ[FMT.INT]
        mainOffset = self.EI.mainAddrOffsetFromText

        # No-PIC case: update special fixups and main() address
        if mainOffset > 0:
            # Step 1-1) Copy everything before the pointer to the main function at _start() in the crti.o object
            # The following procedure handles the fixups in special sections, if any, including
            #     {".text.unlikely", ".text.exit", ".text.startup", ".text.hot"} for compatibility reasons
            #     (class initialization routines mostly reside in these sections in C++ applications)
            if self.fixupsSpecial:
                pos = 0
                # Update all fixups that refer to the BBL within the randomization range
                for FI in self.fixupsSpecial:
                    updateRemainingFixups(FI, 'Special')
                    pos = FI.offset + FI.derefSz

                # Copy all bytes until we meet mainOffset to be updated
                self.instBin += textSection[pos:mainOffset]

            # Otherwise just copy all bytes to update mainOffset
            else:
                self.instBin += textSection[:mainOffset]

            # Step 2-1) Update the location pointing to the main() after reordering if any
            # TODO: A (dirty) hacky solution for handling a different _start implementation of the CRT (i.e., crt1.o)
            '''
             In Ubuntu 18.04 (or maybe other systems), the user-defined main() address at _start
             stores its value to a register in a rip-relative way as follows (*);
             for example, %rdi holds the main() address, hence it needs to be adjusted
              <_start> in Ubuntu 16.04:
                   d:	50                   	push   %rax
                   e:	54                   	push   %rsp
                   f:	49 c7 c0 00 00 00 00 	mov    $0x0,%r8
                  16:	48 c7 c1 00 00 00 00 	mov    $0x0,%rcx
                  1d:	48 c7 c7 00 00 00 00 	mov    $0x0,%rdi (*)
                  24:	e8 00 00 00 00       	callq  29 <_start+0x29>
                  
              <_start> in Ubuntu 18.04:
                   d:   50                      push   %rax
                   e:   54                      push   %rsp
                   f:   4c 8b 05 00 00 00 00    mov    0x0(%rip),%r8     # 16 <_start+0x16>
                  16:   48 8b 0d 00 00 00 00    mov    0x0(%rip),%rcx    # 1d <_start+0x1d>
                  1d:   48 8b 3d 00 00 00 00    mov    0x0(%rip),%rdi    # 24 <_start+0x24> (*)
                  24:   ff 15 00 00 00 00       callq  *0x0(%rip)        # 2a <_start+0x2a>
                  
              As we do not collect fixups for non-user-defined functions, just update it here accordingly.
              Still, using a disassembler is undesirable, thus we do the opcode check (0x3d), which looks too hacky....
            '''
            mainAddr = self.UPK(FMT.INT, textSection[mainOffset: mainOffset + MAINOFFSZ])
            is_main_rip_relative = ord(textSection[mainOffset - 1]) == 0x3d
            if is_main_rip_relative:
                adjust_val = self.EI.entryPoint + mainOffset + MAINOFFSZ
                mainAddr = mainAddr + adjust_val

            self.memo.origMainAddr = mainAddr

            if self.EI.base > 0: # Absolute address (w/o PIC/PIE option)
                mainBBL = self.EI.getBBlByVA(mainAddr)

                # Here the pointer to the new main() should be adjusted in case of a "rip-relative" mov
                if is_main_rip_relative:
                    self.instBin += self.PK(FMT.INT, mainBBL.newVA - adjust_val)
                else:
                    self.instBin += self.PK(FMT.INT, mainBBL.newVA)
                self.memo.instMainAddr = mainBBL.newVA
            else:                # Relative address (w/ PIC/PIE option)
                mainBBL = self.EI.getBBlByVA(self.EI._getTextSecVA() + mainAddr + self.EI.mainAddrOffsetFromText + MAINOFFSZ)
                mainNewOffset = mainBBL.newVA - (self.EI.mainAddrOffsetFromText + MAINOFFSZ + self.EI._getTextSecVA())
                self.instBin += self.PK(FMT.INT, mainNewOffset)
                self.memo.instMainAddr = mainNewOffset

            # Step 3-1) Copy the bytes right before the first object to be reordered
            self.instBin += textSection[mainOffset + MAINOFFSZ:self.EI.reorderObjStartFromText]

        # When there is no main function (i.e., shared object w/ [-fPIC -pie] option),
        # just copy all until the first object for randomization
        else:
            self.memo.isMain = False

            # Step 1-2) Do the same update as above iff there is any fixup in a special section
            if self.fixupsSpecial:
                pos = 0
                for FI in self.fixupsSpecial:
                    # We only need to handle fixups that reference a BBL within the randomization range
                    updateRemainingFixups(FI, 'Special')
                    pos = FI.offset + FI.derefSz

                # Copy all bytes until we meet mainOffset to be updated
                self.instBin += textSection[pos:mainOffset]

            # Step 2-2) Otherwise just copy all bytes up to the first object to be reordered
            else:
                self.instBin += textSection[:self.EI.reorderObjStartFromText]

        # Step 4) Patch the references accordingly
        #         Append the reordered basic blocks in order
        textPos = self.EI.reorderObjStartFromText
        for BBL in self.randomizedBBContainer:
            secOff = BBL.offsetFromSection
            BBLcode = textSection[secOff:secOff + BBL.size]
            fixupBBLOffs = dict()

            # Obtain patching locations (distance from each BBL) for all fixups
            for FI in BBL.Fixups:
                fixupBBLOffs[FI.VA - BBL.VA] = (FI.derefSz, FI.newRefVal)

            # Copy the patched code in a randomized order
            pos = 0
            for off in sorted(fixupBBLOffs.keys()):
                # Code part to preserve
                self.instBin += BBLcode[pos:off]
                # Fixup reference to be updated (already from performTransformation())
                refSz, refVal = fixupBBLOffs[off]
                try:
                    self.instBin += struct.pack(self.getFormat(refSz), refVal)
                except struct.error:
                    logging.critical("BBL#%d [BBLOff:%d] %s" % (BBL.idx, off, BBL))
                    logging.critical("\tStruct error during patching references in .text! " \
                                     "(Deref Size: %dB, Val:0x%04x)" % (refSz, refVal))
                    sys.exit(1)
                pos = off + refSz

            self.instBin += BBLcode[pos:]  # Remaining code in the BBL if any
            textPos += BBL.size
            textBar += 1

        textBar.finish()

        # Step 5) (Optional) Only when there are orphan fixups, which happens
        #          only when the -cfi-icall option has been enabled (implementation-specific)
        #          No orphan fixups have been observed for the other LLVM CFI options yet.
        '''
            1) cfi_icall:          CFI for indirect calls
            2) cfi_vcall:          CFI for virtual function calls
            3) cfi_nvcall:         CFI for calling non-virtual member functions
            4) cfi_unrelated_cast: CFI for the casts between objects of unrelated types
            5) cfi_derived_cast:   CFI for the casts between a base and a derived class
            6) cfi_cast_strict:    specific instance where 5) would not catch an illegal cast
        '''
        if self.fixupsOrphan:
            for FI in self.fixupsOrphan:
                # Assume orphan fixups for CFI always have a refBB with a relative value
                try:
                    trampoline = '\xE9'     # unconditional jump instruction
                    FI.newRefTo = FI.refBB.newVA
                    FI.newRefVal = FI.newRefTo - FI.VA - FI.derefSz
                    # 8 byte aligned with three 0xCCs at all times
                    trampoline += struct.pack(self.getFormat(FI.derefSz), FI.newRefVal) + '\xCC'*3
                    self.instBin += trampoline
                    logging.info("[Orphan Fixup] 0x%x->0x%x@0x%x (BBVA: 0x%x->0x%x)" %
                             (FI.derefVal, FI.newRefVal, FI.VA, FI.refBB.VA, FI.refBB.newVA))
                except:
                    logging.critical("[Orphan Fixup] Could not find a reference BB!: %s" % (FI))
            textPos += 8 * len(self.fixupsOrphan)

        # Step 6) Copy the remaining code in the text section (that has not been reordered)
        self.instBin += textSection[textPos:]
Code example #10
File: join_views.py Project: fendaq/GL_BD_LSTM
def get_ts(path):
    ts_str = util.str.find_all(path, r"\d+\.\d+")[0]
    return float(ts_str)


image_dict = defaultdict(list)
for view_idx, view_dir in enumerate(view_dirs):
    image_names = util.io.ls(view_dir, ".jpg")
    for image_name in image_names:
        ts = get_ts(image_name)
        image_dict[ts].append(view_dir + "/" + image_name)

timestamps = sorted(image_dict.keys())
view_names = [util.io.get_filename(name) for name in view_dirs]
output_dir = "~/temp/no-use/" + util.str.join(view_names, "+")
bar = util.ProgressBar(len(image_dict))
for ts in timestamps:
    bar.move(1)
    images = []
    output_path = util.io.join_path(output_dir,
                                    str(ts) + ".jpg")
    if util.io.exists(output_path):
        continue
    for image_path in image_dict[ts]:
        image_name = util.io.get_filename(image_path)
        image = util.img.imread(image_path, rgb=True)
        images.append(image)
    image_data = np.concatenate(images, axis=1)
    util.img.imwrite(output_path, image_data)
    #util.plt.show_images(images = images, titles = view_names, save = True, show = True, path = image_path, axis_off = True)
Code example #11
def join_views():
    image_dict = defaultdict(list)
    timestamps = set()
    planning_dict = {}
    num_views = len(view_dirs)
    for view_idx, view_dir in enumerate(view_dirs):
        view_dirs[view_idx] = util.io.get_absolute_path(view_dir)
        image_names = util.io.ls(view_dir, ".jpg")
        for image_name in image_names:
            ts = get_ts(image_name)
            if util.str.contains(image_name, "Planning", ignore_case=True):
                planning_dict[ts] = view_dir + "/" + image_name
                continue
            if util.str.contains(image_name, "Fusion"):
                timestamps.add(ts)
            ts = get_ts(image_name)
            image_dict[ts].append(view_dir + "/" + image_name)
    planning_timestamps = list(planning_dict.keys())
    planning_timestamps.sort()
    if planning_timestamps:
        num_views -= 1
    timestamps = list(timestamps)
    timestamps.sort()
    view_names = [util.io.get_filename(name) for name in view_dirs]
    output_dir = util.io.join_path(
        "~/temp/no-use/",
        util.str.replace_all(view_dirs[0], "/", "_") + util.str.join(view_names, "+"))
    bar = util.ProgressBar(len(image_dict))
    for ts in timestamps:
        bar.move(1)
        images = []
        camera_images = []
        output_path = util.io.join_path(
            output_dir,
            str(get_frame(image_dict[ts][0])) + "_" + str(ts) + ".jpg")
        if len(image_dict[ts]) < num_views:
            continue
        if util.io.exists(output_path):
            continue
        for image_path in image_dict[ts]:
            image_name = util.io.get_filename(image_path)
            image = util.img.imread(image_path, rgb=False)
            if util.str.contains(image_name, 'Camera'):
                camera_images.append(image)
            else:
                images.append(image)
        min_diff = 1000
        min_idx = -1
        for pidx, pts in enumerate(planning_timestamps):
            diff = abs(pts - ts)
            if diff < min_diff:
                min_idx = pidx
                min_diff = diff

        if min_diff < TIME_DIFF_TH:
            images.append(
                get_planning_img(planning_dict[planning_timestamps[min_idx]],
                                 images[0].shape[:-1], min_idx))

        image_data = np.concatenate(images, axis=1)
        if camera_images:
            h, w = camera_images[0].shape[:-1]
            camera_width = images[0].shape[1]
            if len(camera_images) == 1:
                camera_shape = (image_data.shape[0], camera_width, 3)
                camera_data = np.zeros(camera_shape, dtype=np.uint8)
                camera_height = int(h * (camera_width * 1.0 / w))
                ci = camera_images[0]
                ci = util.img.resize(ci, (camera_width, camera_height))
                camera_data[:ci.shape[0], :ci.shape[1], :] = ci
            else:
                camera_data = np.concatenate(camera_images, axis=0)
                camera_data = util.img.resize(
                    camera_data, (camera_width, image_data.shape[0]))
    #         image_height = max([camera_data.shape[0], image_data.shape[0]])
    #         camera_data = util.img.resize(camera_data, (camera_data.shape[1], image_height))
    #         image_data = util.img.resize(image_data, (image_data.shape[1], image_height))
            image_data = np.concatenate([camera_data, image_data], axis=1)
        util.img.imwrite(output_path, image_data)
Code example #12
    def train(self):
        """
        Trains the architecture.
        """
        training_data_photo = torchvision.datasets.ImageFolder(
            self.in_path_photo,
            torchvision.transforms.Compose([torchvision.transforms.ToTensor()
                                            ]))

        data_loader_photo = torch.utils.data.DataLoader(
            training_data_photo,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=2)

        training_data_oil = torchvision.datasets.ImageFolder(
            self.in_path_oil,
            torchvision.transforms.Compose([torchvision.transforms.ToTensor()
                                            ]))

        data_loader_oil = torch.utils.data.DataLoader(
            training_data_oil,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=2)

        l1_criterion = torch.nn.L1Loss().cuda(
        ) if self.use_cuda else torch.nn.L1Loss()

        progress_bar = util.ProgressBar()

        for epoch in range(self.start_epoch, self.num_epochs + 1):
            print('Epoch {}/{}'.format(epoch, self.num_epochs))
            for i, (photo_batch, oil_batch) in enumerate(
                    zip(data_loader_photo, data_loader_oil), 1):
                x_photo, _ = photo_batch
                x_oil, _ = oil_batch
                x_photo = x_photo.to(self.device)
                x_oil = x_oil.to(self.device)

                if x_photo.shape[0] != x_oil.shape[0]:
                    continue

                z_oil, x_rec_photo = self._train_photo_to_oil_and_back(
                    x_photo, l1_criterion)

                v = self.auto_photo.encode(x_photo)
                u_dash = self.map_v_to_u(v)
                z_oil = self.auto_oil.decode(u_dash)

                self._update_discriminator(self.discriminator_oil,
                                           self.d_oil_optimizer,
                                           x_real=x_oil,
                                           x_fake=z_oil,
                                           losses=self.d_oil_losses)

                z_photo, x_rec_oil = self._train_oil_to_photo_and_back(
                    x_oil, l1_criterion)

                u = self.auto_oil.encode(x_oil)
                v_dash = self.map_u_to_v(u)
                z_photo = self.auto_photo.decode(v_dash)

                self._update_discriminator(self.discriminator_photo,
                                           self.d_photo_optimizer,
                                           x_real=x_photo,
                                           x_fake=z_photo,
                                           losses=self.d_photo_losses)

                if self.verbose:
                    info_str = 'd_photo: {:.4f}, d_oil: {:.4f}, g_photo: {:.4f}, g_oil: {:.4f}'.\
                               format(self.d_photo_losses[-1], self.d_oil_losses[-1], self.g_photo_losses[-1],
                                      self.g_oil_losses[-1])

                    progress_bar.update(max_value=len(data_loader_oil),
                                        current_value=i + 1,
                                        info=info_str)

            if not os.path.exists('saves/'):
                os.makedirs('saves/')
            self.save(epoch, path='saves')

            progress_bar.new_line()
Code example #13
    def train(self,
              save_dir,
              num_epochs=75,
              batch_size=256,
              learning_rate=0.001,
              test_each_epoch=False,
              verbose=False):
        """Trains the network.

        Parameters
        ----------
        save_dir : str
            The directory in which the parameters will be saved
        num_epochs : int
            The number of epochs
        batch_size : int
            The batch size
        learning_rate : float
            The learning rate
        test_each_epoch : boolean
            True: Test the network after every training epoch, False: no testing
        verbose : boolean
            True: Print training progress to console, False: silent mode
        """
        self.optimizer = torch.optim.Adam(self.net.parameters(),
                                          lr=learning_rate,
                                          weight_decay=1e-5)
        self.net.train()

        train_transform = transforms.Compose([
            util.Cutout(num_cutouts=2, size=8, p=0.8),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = datasets.CIFAR10('data/cifar',
                                         train=True,
                                         download=True,
                                         transform=train_transform)
        data_loader = torch.utils.data.DataLoader(train_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=True)
        criterion = torch.nn.CrossEntropyLoss().cuda(
        ) if self.use_cuda else torch.nn.CrossEntropyLoss()

        progress_bar = util.ProgressBar()

        for epoch in range(self.start_epoch, num_epochs + 1):
            print('Epoch {}/{}'.format(epoch, num_epochs))

            epoch_correct = 0
            epoch_total = 0
            for i, data in enumerate(data_loader, 1):
                images, labels = data
                images = images.to(self.device)
                labels = labels.to(self.device)

                self.optimizer.zero_grad()
                outputs = self.net.forward(images)
                loss = criterion(outputs, labels.squeeze_())
                loss.backward()
                self.optimizer.step()

                _, predicted = torch.max(outputs.data, dim=1)
                batch_total = labels.size(0)
                batch_correct = (predicted == labels.flatten()).sum().item()

                epoch_total += batch_total
                epoch_correct += batch_correct

                if verbose:
                    # Update progress bar in console
                    info_str = 'Last batch accuracy: {:.4f} - Running epoch accuracy {:.4f}'.\
                                format(batch_correct / batch_total, epoch_correct / epoch_total)
                    progress_bar.update(max_value=len(data_loader),
                                        current_value=i,
                                        info=info_str)

            self.train_accuracies.append(epoch_correct / epoch_total)
            if verbose:
                progress_bar.new_line()

            if test_each_epoch:
                test_accuracy = self.test()
                self.test_accuracies.append(test_accuracy)
                if verbose:
                    print('Test accuracy: {}'.format(test_accuracy))

            # Save parameters after every epoch
            self.save_parameters(epoch, directory=save_dir)
Code example #14
def general_image_folder(opt):
    """Create lmdb for general image folders
    Users should define the keys, such as: '0321_s035' for DIV2K sub-images
    If all the images have the same resolution, it will only store one copy of resolution info.
        Otherwise, it will store every resolution info.
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set to False to limit memory usage
    BATCH = 5000  # lmdb commits after every BATCH images if read_all_imgs is False
    n_thread = 40
    ########################################################
    img_folder = opt['img_folder']
    lmdb_save_path = opt['lmdb_save_path']
    meta_info = {'name': opt['name']}
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = sorted(glob.glob(osp.join(img_folder, '*')))
    keys = []
    for img_path in all_img_list:
        keys.append(osp.splitext(osp.basename(img_path))[0])

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {
        }  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(
            n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker,
                             args=(path, key),
                             callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(
            len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0],
                                   cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    resolutions = []
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(
            path, cv2.IMREAD_UNCHANGED)
        if data.ndim == 2:
            H, W = data.shape
            C = 1
        else:
            H, W, C = data.shape
        txn.put(key_byte, data)
        resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    # check whether all the images are the same size
    assert len(keys) == len(resolutions)
    if len(set(resolutions)) <= 1:
        meta_info['resolution'] = [resolutions[0]]
        meta_info['keys'] = keys
        print('All images have the same resolution. Simplify the meta info.')
    else:
        meta_info['resolution'] = resolutions
        meta_info['keys'] = keys
        print(
            'Not all images have the same resolution. Save meta info for each image.'
        )

    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
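
general_image_folder stores the raw pixel bytes of each image under its filename stem and records the per-entry 'C_H_W' resolution in meta_info.pkl. A sketch of how such an lmdb could be read back (the path is a placeholder and 8-bit image data is assumed):

import pickle
import lmdb
import numpy as np

lmdb_path = 'example.lmdb'  # placeholder for opt['lmdb_save_path']
with open(lmdb_path + '/meta_info.pkl', 'rb') as f:
    meta_info = pickle.load(f)
keys = meta_info['keys']
resolutions = meta_info['resolution']
if len(resolutions) == 1:            # all images share one resolution
    resolutions = resolutions * len(keys)

env = lmdb.open(lmdb_path, readonly=True, lock=False)
with env.begin(write=False) as txn:
    buf = txn.get(keys[0].encode('ascii'))
C, H, W = [int(x) for x in resolutions[0].split('_')]
img = np.frombuffer(buf, dtype=np.uint8).reshape(H, W, C)  # assumes 8-bit images
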
Code example #15
import util
image_dir = util.io.get_absolute_path(util.argv[1])
target_dir = "/data/fusion/Planning_rename"
if len(util.argv) > 2:
    target_dir = util.argv[2]
    
image_names = util.io.ls(image_dir, ".jpg")

def get_ts(path):
    ts_str = util.str.find_all(path, r"\d+\.\d+")[0]
    return float(ts_str)

image_names.sort()

pb = util.ProgressBar(len(image_names))
for idx, name in enumerate(image_names):
    src_path = util.io.join_path(image_dir, name)
    ts = get_ts(name)
    new_name = str(idx) + "_Planning_" + util.time.timestamp2str(ts) + "_" + str(ts) + ".jpg"
    target_path = util.io.join_path(target_dir, new_name)
    pb.move(1)
    if util.io.exists(target_path):
        continue
    image_data = util.img.imread(src_path)
    h, w = image_data.shape[:-1]
    pos = (0, int(h * 0.4))
    util.img.put_text(image_data, new_name, pos, 0.5, util.img.COLOR_BGR_RED, 1)
    util.img.imwrite(target_path, image_data)
Code example #16
File: system.py Project: cdelgehier/fabric_cmdline
    def piped_tar(self,src_host,src_path,dst_path,src_user=env.user):
        """
        Copy an object from [src_user]@<src_host>:<src_path> to this host into <dst_path>
        src_user defaults to the ssh client user (env.user)
        """
        import paramiko
        if util._is_host_up(env.host, int(env.port)) is False:
            return False
        if util._is_host_up(src_host, int(env.port)) is False:
            return False
        if src_path.endswith('/'):
            src_path = re.sub("/$","",src_path)
        src_path_last_dir = re.sub(r"^(.*)/([^\/]+)$","\\1",src_path)
        src_path_end = re.sub(r"^"+src_path_last_dir+"/([^\/]+)$","\\1",src_path)
        save_host = env.host
        save_user = env.user
        save_host_string = env.host_string
        env.host = src_host
        env.host_string = src_host
        env.user = src_user
        check_src = run("ls "+src_path)
        if check_src.failed:
            print(red("FAILURE")+" cannot found <"+src_path+"> on <"+env.host+">")
            return False
        env.host = save_host
        env.user = save_user
        env.host_string = save_host_string
        check_src = run("ls "+dst_path)
        if check_src.failed:
            print(red("FAILURE")+" cannot found <"+dst_path+"> on <"+env.host+">")
            return False
        if not dst_path.endswith('/'):
            dst_path += '/'
        command = "ssh "+src_user+'@'+src_host+" \"cd "+src_path_last_dir+" && tar zcvf - "+src_path_end+"\" | ssh "+env.user+"@"+env.host+" \"cd "+dst_path+" && tar zxvf -\""
        puts(command)
        save_host = env.host
        save_host_string = env.host_string
        env.host_string = src_host
        env.host = src_host
        orig_size_file = self._fabrun("stat -c %s "+src_path)
        orig_size_file_list = orig_size_file.splitlines()
        orig_size_file = 0
        for f in orig_size_file_list:
            try:
                int(f)
                orig_size_file+=int(f)
            except:
                orig_size_file+=0
        if int(orig_size_file) == 0:
            print red("Error: ")+"<"+src_path+"> file size is null"
            return False
        puts("Orig file size = "+str(orig_size_file))
        env.host = save_host
        env.host_string = save_host_string

        newpid = os.fork()

        if newpid == 0:
            # child

            puts("child "+str(os.getpid()))

            Bar = util.ProgressBar(int(orig_size_file), 60, src_host+'->'+env.host)
            client = paramiko.SSHClient()
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            client.connect(env.host, username=env.user)
            while True:
                _, stdout, _ = client.exec_command("stat -c %s "+dst_path+src_path_end+" 2>/dev/null")
                size_file = stdout.read()
                puts("found size = "+size_file)
                size_file_list = size_file.splitlines()
                size_file = 0
                for s in size_file_list:
                    try:
                        int(s)
                        size_file+=int(s)
                    except:
                        size_file+=0
                puts("Dst file size = "+str(size_file))

                Bar.update(int(size_file))

                if int(size_file) >= int(orig_size_file):
                    print " copy terminated"
                    os._exit(0)

        else:
            # parent
            result = local(command,capture=True)
            os.waitpid(newpid, 0)

            if not result.failed:
                print(green("SUCCESS")+" during piped tar through ssh copy")
                puts(result.stderr)
            else:
                print(red("FAILURE")+" during piped tar through ssh copy")
                print(result.stderr)