Example #1
0
def subtract_bg(fn, opt: PreProcOpts):
    """Subtracts the background of a diffraction pattern by azimuthal integration excluding the Bragg peaks.

    Arguments:
        fn {str or list} -- file name(s) of the data set to process
        opt {PreProcOpts} -- pre-processing options

    Returns:
        list -- file names of the background-subtracted data set
    """

    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        # Print only when verbose, or when any argument is an exception
        # (errors are always reported). The original tested the undefined
        # name `err` instead of the loop variable, raising NameError.
        if not (opt.verbose or any(isinstance(a, Exception) for a in args)):
            return
        if isinstance(fn, list):
            dispfn = os.path.basename(fn[0]) + ' etc.'
        else:
            dispfn = os.path.basename(fn)
        idstring = '[{} - {} - subtract_bg] '.format(
            datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    ds = Dataset().from_list(fn)
    ds.open_stacks(readonly=False)

    if opt.rerun_peak_finder:
        # Recompute peak positions and wrap them as dask arrays whose
        # first-axis chunking matches the image stack, so map_blocks can
        # pair each image chunk with its peak data.
        pks = find_peaks(ds, opt=opt)
        nPeaks = da.from_array(pks['nPeaks'][:, np.newaxis, np.newaxis],
                               chunks=(ds.centered.chunks[0], 1, 1))
        peakX = da.from_array(pks['peakXPosRaw'][:, :, np.newaxis],
                              chunks=(ds.centered.chunks[0], -1, 1))
        peakY = da.from_array(pks['peakYPosRaw'][:, :, np.newaxis],
                              chunks=(ds.centered.chunks[0], -1, 1))
    else:
        # Reuse the peak data already stored in the data set.
        nPeaks = ds.nPeaks[:, np.newaxis, np.newaxis]
        peakX = ds.peakXPosRaw[:, :, np.newaxis]
        peakY = ds.peakYPosRaw[:, :, np.newaxis]

    original = ds.centered
    # Lazy, block-wise background removal; the pattern center is taken to
    # be the image center (the stack is already centered).
    bg_corrected = da.map_blocks(proc2d.remove_background,
                                 original,
                                 original.shape[2] / 2,
                                 original.shape[1] / 2,
                                 nPeaks,
                                 peakX,
                                 peakY,
                                 peak_radius=opt.peak_radius,
                                 filter_len=opt.filter_len,
                                 dtype=np.float32 if opt.float else np.int32,
                                 chunks=original.chunks)

    ds.add_stack('centered', bg_corrected, overwrite=True)
    ds.change_filenames(opt.nobg_file_suffix)
    ds.init_files(keep_features=False, overwrite=True)
    ds.store_tables(shots=True, features=True)
    ds.open_stacks(readonly=False)

    try:
        # Triggers the actual dask computation and writes the results.
        ds.store_stacks(overwrite=True, progress_bar=False)
    except Exception as err:
        log('Error during background correction:', err)
        raise err
    finally:
        ds.close_stacks()

    return ds.files
Example #2
0
def cumulate(fn, opt: PreProcOpts):
    """Applies cumulative summation to movie-frame stacks of a data set.

    Frames from opt.cum_first_frame onward are replaced by their running sum
    along the stack axis; earlier frames are left as-is. At the moment, the
    summed frame stacks must have the same shape as the raw data.

    Arguments:
        fn {str or list} -- file name(s) of the data set to process
        opt {PreProcOpts} -- pre-processing options

    Raises:
        err: any exception raised while storing the stacks is re-raised

    Returns:
        list -- file names of the cumulated data set
    """

    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        dispfn = (os.path.basename(fn[0]) + ' etc.') if isinstance(fn, list) \
            else os.path.basename(fn)
        idstring = '[{} - {} - cumulate] '.format(
            datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    dssel = Dataset().from_list(fn)
    log('Cumulating from frame', opt.cum_first_frame)
    dssel.open_stacks(readonly=False)

    # chunk boundaries matching the aggregation groups
    chunks = tuple(
        dssel.shots.groupby(opt.idfields).count()['selected'].values)
    for name, stack in dssel.stacks.items():
        # only rechunk stacks that do not already match; 'index' is exempt
        if stack.chunks[0] == chunks or name == 'index':
            continue
        log(name, 'needs rechunking...')
        dssel.add_stack(name, stack.rechunk({0: chunks}), overwrite=True)
    dssel._zchunks = chunks

    def cumfunc(movie):
        # in-place running sum from the configured first frame onward
        first = opt.cum_first_frame
        movie[first:, ...] = np.cumsum(movie[first:, ...], axis=0)
        return movie

    for name in opt.cum_stacks:
        stack = dssel.stacks[name]
        dssel.stacks[name] = stack.map_blocks(cumfunc, dtype=stack.dtype)

    dssel.change_filenames(opt.cum_file_suffix)
    dssel.init_files(overwrite=True, keep_features=False)
    log('File initialized, writing tables...')
    dssel.store_tables(shots=True, features=True)

    try:
        dssel.open_stacks(readonly=False)
        log('Writing stack data...')
        dssel.store_stacks(overwrite=True, progress_bar=False)
    except Exception as err:
        log('Cumulative processing failed.')
        raise err
    finally:
        dssel.close_stacks()
        log('Cumulation done.')

    return dssel.files
Example #3
0
def from_raw(fn, opt: PreProcOpts):
    """Runs the initial processing chain on a raw data set.

    Aggregates/selects shots, applies saturation, flat-field and dead-pixel
    corrections, finds the beam center (center of mass followed by a fast
    Lorentzian fit), centers the images, and computes virtual-detector
    (ADF) signals. Results are stored into a new set of files.

    Arguments:
        fn {str or list} -- file name(s) of the raw data set
        opt {PreProcOpts} -- pre-processing options

    Returns:
        list -- file names of the processed (aggregated/centered) data set
    """

    if isinstance(fn, list) and len(fn) == 1:
        fn = fn[0]

    def log(*args):
        # Print only when verbose, or when any argument is an exception
        # (errors are always reported). The original tested the undefined
        # name `err` instead of the loop variable, raising NameError.
        if not (opt.verbose or any(isinstance(a, Exception) for a in args)):
            return
        if isinstance(fn, list):
            dispfn = os.path.basename(fn[0]) + ' etc.'
        else:
            dispfn = os.path.basename(fn)
        idstring = '[{} - {} - from_raw] '.format(
            datetime.datetime.now().time(), dispfn)
        print(idstring, *args)

    t0 = time()
    dsraw = Dataset().from_list(fn)

    reference = imread(opt.reference)
    pxmask = imread(opt.pxmask)

    os.makedirs(opt.scratch_dir, exist_ok=True)
    os.makedirs(opt.proc_dir, exist_ok=True)
    dsraw.open_stacks(readonly=True)

    if opt.aggregate:
        # Sum raw counts over the aggregation groups.
        dsagg = dsraw.aggregate(file_suffix=opt.agg_file_suffix,
                                new_folder=opt.proc_dir,
                                force_commensurate=False,
                                how={'raw_counts': 'sum'},
                                query=opt.agg_query)
    else:
        # No aggregation: just select the shots matching the query.
        dsagg = dsraw.get_selection(opt.agg_query,
                                    new_folder=opt.proc_dir,
                                    file_suffix=opt.agg_file_suffix)

    log(f'{dsraw.shots.shape[0]} raw, {dsagg.shots.shape[0]} aggregated/selected.'
        )

    if opt.rechunk is not None:
        dsagg.rechunk_stacks(opt.rechunk)

    # Saturation, flat-field and dead-pixel correction
    if opt.correct_saturation:
        stack_ff = proc2d.apply_flatfield(
            proc2d.apply_saturation_correction(dsagg.raw_counts,
                                               opt.shutter_time,
                                               opt.dead_time), reference)
    else:
        stack_ff = proc2d.apply_flatfield(dsagg.raw_counts, reference)

    stack = proc2d.correct_dead_pixels(stack_ff,
                                       pxmask,
                                       strategy='replace',
                                       replace_val=-1,
                                       mask_gaps=True)

    # Stack in central region along x (note that the gaps are not masked this time)
    xrng = slice((opt.xsize - opt.com_xrng) // 2,
                 (opt.xsize + opt.com_xrng) // 2)
    stack_ct = proc2d.correct_dead_pixels(stack_ff[:, :, xrng],
                                          pxmask[:, xrng],
                                          strategy='replace',
                                          replace_val=-1,
                                          mask_gaps=False)

    # Define COM threshold as fraction of highest pixel (after discarding some too high ones)
    thr = stack_ct.max(axis=1).topk(10, axis=1)[:, 9].reshape(
        (-1, 1, 1)) * opt.com_threshold
    # Shift the COM back into full-frame coordinates (x offset of the crop).
    com = proc2d.center_of_mass2(
        stack_ct, threshold=thr) + [[(opt.xsize - opt.com_xrng) // 2, 0]]

    # Lorentzian fit in region around the found COM
    lorentz = compute.map_reduction_func(stack,
                                         proc2d.lorentz_fast,
                                         com[:, 0],
                                         com[:, 1],
                                         radius=opt.lorentz_radius,
                                         limit=opt.lorentz_maxshift,
                                         scale=7,
                                         threads=False,
                                         output_len=4)
    ctr = lorentz[:, 1:3]

    # calculate the centered image by shifting and padding with -1
    centered = proc2d.center_image(stack,
                                   ctr[:, 0],
                                   ctr[:, 1],
                                   opt.xsize,
                                   opt.ysize,
                                   -1,
                                   parallel=True)

    # add the new stacks to the aggregated dataset
    alldata = {
        'center_of_mass': com,
        'lorentz_fit': lorentz,
        'beam_center': ctr,
        'centered': centered,
        # pixels padded with -1 are invalid; everything else is valid
        'pxmask_centered': (centered != -1).astype(np.uint16),
        'adf1': proc2d.apply_virtual_detector(centered,
                                              opt.r_adf1[0], opt.r_adf1[1]),
        'adf2': proc2d.apply_virtual_detector(centered,
                                              opt.r_adf2[0], opt.r_adf2[1]),
    }
    for lbl, stk in alldata.items():
        print('adding', lbl, stk.shape)
        dsagg.add_stack(lbl, stk, overwrite=True)

    # make the files and crunch the data
    try:
        dsagg.init_files(overwrite=True)
        dsagg.store_tables(shots=True, features=True)
        dsagg.open_stacks(readonly=False)
        dsagg.delete_stack(
            'raw_counts',
            from_files=False)  # we don't need the raw counts in the new files
        dsagg.store_stacks(
            overwrite=True,
            progress_bar=False)  # this does the actual calculation

        log('Finished first centering', dsagg.centered.shape[0], 'shots after',
            time() - t0, 'seconds')

    except Exception as err:
        log('Raw processing failed.', err)
        raise err

    finally:
        dsagg.close_stacks()
        dsraw.close_stacks()

    return dsagg.files