Esempio n. 1
0
def read_nsls2_fxi18_h5_mpi(scan_id, data_path=".", size=None):
    """Read one NSLS-II FXI fly scan, requesting only this rank's projections.

    The shape of ``/img_tomo`` is read first, a default ``Distribution`` is
    built from it, and the (start, stop, step) range belonging to this rank
    is handed to ``read_nsls2_fxi18_h5``.  Flats and darks are kept whole on
    every rank.

    Args:
        scan_id: Scan number used to build ``fly_scan_id_<scan_id>.h5``.
        data_path: Directory holding the file, as ``str`` or ``pathlib.Path``.
        size: Optional crop size; when given, projections, flats and darks
            are cropped on axes 1 and 2 with ``utils.center_slice``.

    Returns:
        ``(mpi_projs, flats, darks, mpi_thetas)``.  The first and last items
        are MpiArrays created with the projection distribution; flats and
        darks are uint16 arrays.

    ``mpi_size``, ``mpi_rank`` and ``logger`` come from the module scope.
    """
    # Local import keeps this block independent of the file-level imports.
    from pathlib import Path

    # The default data_path is a plain str, which does not support "/",
    # so normalise to Path before joining.
    filename = str(Path(data_path) / f"fly_scan_id_{scan_id}.h5")
    # get data size first, then load only data required, split on projs axis, 0
    with h5py.File(filename, "r") as f:
        projs_shape = f["/img_tomo"].shape
    distribution = Distribution.default(projs_shape, mpi_size)
    first = distribution.offsets[mpi_rank]
    proj_slc = (first, first + distribution.sizes[mpi_rank], 1)
    logger.info(f"{projs_shape}, {mpi_size}, {proj_slc}")

    projs, flats, darks, thetas = read_nsls2_fxi18_h5(filename, proj_slc)
    # The flats and darks are needed by all nodes.
    # NOTE: flats and darks should be uint16, but sometimes are float32
    flats = np.require(flats, dtype=np.uint16)
    darks = np.require(darks, dtype=np.uint16)
    if size is not None:
        # The crop window is derived from projs and reused for flats/darks.
        slc1 = utils.center_slice(projs, size, axis=1)
        slc2 = utils.center_slice(projs, size, axis=2)
        projs = np.require(projs[:, slc1, slc2], requirements="C")
        flats = np.require(flats[:, slc1, slc2], requirements="C")
        darks = np.require(darks[:, slc1, slc2], requirements="C")
    mpi_projs = MpiArray(projs, distribution=distribution)
    mpi_thetas = MpiArray(thetas, distribution=distribution)
    return mpi_projs, flats, darks, mpi_thetas
def grid_search_ss_split(atlas, I, W, R, u, bounds=[-20, 20]):
    """Grid-search an integer pixel-shift update for every local pixel.

    For each pixel (i, j) of this rank's slab, every integer offset in the
    square window ``bounds[0]..bounds[1]`` (inclusive, both axes) is added
    to the current shift ``u`` and scored with the mean squared difference
    between ``atlas * whitefield`` and the measured frames.  The
    best-scoring offset is added to the output shift map.

    Args:
        atlas: 2-D array indexed by rounded (ss, fs) positions.
        I: Frames container; ``I.arr`` is indexed as [frame, i, j].
        W: Whitefield container; ``W.arr`` is either 2-D or [frame, i, j].
        R: Per-frame positions with columns (ss, fs).
        u: Shift container; ``u.arr`` is indexed as [axis, i, j].
        bounds: ``[low, high]`` limits of the search window (cast to int).

    Returns:
        MpiArray built from ``u.arr`` (axis=1) holding the updated shifts.

    ``rank``, ``comm`` and ``callback`` are not parameters; they are looked
    up in the enclosing module scope.
    """
    ss_offset = W.distribution.offsets[rank]
    print('rank, W.shape, offset:', rank, W.shape, W.arr.shape, ss_offset,
          W.distribution.offsets[rank])

    def fun(ui, uj, i, j, w):
        # np.int was removed in NumPy 1.24; it was an alias of builtin int.
        ss = np.rint(-R[:, 0] + ui + i + ss_offset).astype(int)
        fs = np.rint(-R[:, 1] + uj + j).astype(int)
        # keep only frames whose sampled position falls inside the atlas
        m = (ss > 0) * (ss < atlas.shape[0]) * (fs > 0) * (fs < atlas.shape[1])
        forw = atlas[ss[m], fs[m]]
        m1 = (forw > 0) * (w[m, i, j] > 0)
        n = np.sum(m1)
        if n > 0:
            err = np.sum(m1 * (forw * w[m, i, j] - I.arr[m, i, j])**2) / n
        else:
            # no usable sample: make this candidate lose the argmin
            err = 1e100
        return err

    # define the search window
    steps = np.arange(int(bounds[0]), int(bounds[1]) + 1, 1)
    k_ss, k_fs = np.meshgrid(steps, steps, indexing='ij')
    kls = np.vstack((k_ss.ravel(), k_fs.ravel())).T

    # define the pupil indices
    ss = np.arange(W.arr.shape[-2])
    fs = np.arange(W.arr.shape[-1])

    if len(W.arr.shape) == 2:
        # replicate a single whitefield so it can be indexed per frame
        w = np.array([W.arr for i in range(I.shape[0])])
    else:
        w = W.arr

    # Smallest "last local row index" over all ranks; used below to limit
    # the gather calls (presumably so every rank joins each collective).
    ss_min = comm.allreduce(np.max(ss), op=MPI.MIN)
    errs = np.empty((len(kls), ), dtype=np.float64)
    # NOTE(review): u_out is built directly from u.arr, so unless MpiArray
    # copies its input the in-place += below also modifies u -- confirm.
    u_out = MpiArray(u.arr, axis=1)
    for i in ss:
        print(rank, i, i + ss_offset)
        sys.stdout.flush()
        for j in fs:
            errs.fill(1e100)
            # Only pixels with some positive whitefield can score below
            # 1e100 (fun requires w > 0), so test once per pixel instead of
            # once per candidate.  The original wrote np.any(w) > 0.
            if np.any(w[:, i, j] > 0):
                for n_kl, kl in enumerate(kls):
                    errs[n_kl] = fun(u.arr[0, i, j] + kl[0],
                                     u.arr[1, i, j] + kl[1], i, j, w)

            # NOTE(review): when every candidate ties at 1e100, argmin picks
            # the first offset (bounds[0], bounds[0]) -- confirm intended.
            best = np.argmin(errs)
            u_out.arr[:, i, j] += kls[best]

        # i % 1 is always 0, so the callback runs after every ss row
        # (not every 10 rows) while i <= ss_min.
        if i % 1 == 0 and i <= ss_min and callback is not None:
            ug = u_out.gather()
            callback(ug)

    return u_out
Esempio n. 3
0
def find_angle_mpi(mpi_rec, axis=0):
    """Estimate the layer tilt angle and layer extent of a reconstruction.

    Works on a deep copy of ``mpi_rec``: negative values are zeroed, each
    slice along axis 1 is Gaussian-smoothed and differenced with its
    neighbour, and the result is collapsed along ``axis`` into a 2-D
    ``layer_diff`` image that is gathered on every rank.  Rank 0 writes
    that image to ``layer_diff_<axis>.tiff``, derives the angle and the
    start/end of the layers from it, and the three values are broadcast.

    Args:
        mpi_rec: Distributed reconstruction (provides ``copy``, ``scatter``,
            ``mpi_rank`` and ``comm``).
        axis: Axis to collapse; 0 takes the averaging branch, anything else
            takes the branch marked as specific to axis=2.

    Returns:
        ``(angle_deg, start, end)`` identical on every rank.
    """
    from mpiarray import MpiArray
    mpi_rec_diff = mpi_rec.copy(deep=True)
    # take difference between layers to measure angle
    rec_diff = mpi_rec_diff.scatter(0, 5)  # add padding
    rec_diff[rec_diff < 0] = 0
    # smooth data between IC layers and take difference
    for i in range(rec_diff.shape[1]):
        # filter in place, then subtract the freshly smoothed slice from
        # the previous (already smoothed) one
        filters.gaussian_filter(rec_diff[:, i], (9, 9), output=rec_diff[:, i])
        if i > 0:
            rec_diff[:, i - 1] -= rec_diff[:, i]
    rec_diff = mpi_rec_diff.scatter(0)  # remove padding

    # now collapse axis to correct angle (already dist on zero)
    if axis == 0:
        # local partial sums are gathered, summed again and divided by the
        # local slab length
        layer_diff = np.sum(rec_diff, axis=0, keepdims=True)
        layer_diff = MpiArray(layer_diff).allgather()
        layer_diff = np.sum(layer_diff, axis=axis)
        layer_diff /= rec_diff.shape[0]
    else:
        layer_diff = np.sum(rec_diff, axis=axis)
        layer_diff = MpiArray(layer_diff).allgather()
        layer_diff = np.rot90(layer_diff,
                              axes=(1, 0))  # TODO: currently specific to axis=2

    # smooth out layer_diff
    filters.gaussian_filter(layer_diff, (3, 3), output=layer_diff)
    # remove sides: drop a quarter of the rows at each end
    # NOTE(review): with fewer than 4 rows side_trim is 0 and the slice
    # below ([0:-0]) is empty -- confirm inputs are always larger.
    side_trim = layer_diff.shape[0] // 4
    layer_diff = layer_diff[side_trim:-side_trim]

    if mpi_rec.mpi_rank == 0:
        # TODO: need to return layer_diff or ignore, don't write to disk
        dxchange.write_tiff(layer_diff,
                            fname='layer_diff_%d.tiff' % axis,
                            overwrite=True)
        angle_deg = find_angle_from_layer_diff(layer_diff,
                                               axis=1)  # FIXME: axis
        # find where the layer starts and stops
        rotated_layer_diff = rotate(layer_diff, angle_deg, reshape=False)
        start, end = find_extent_from_layer_diff(rotated_layer_diff, axis=1)
        # NOTE: don't forget to add back on the trimmed off amount from earlier!
        start += side_trim
        end += side_trim
    else:
        # placeholders, overwritten by the broadcasts below
        angle_deg, start, end = None, None, None
    angle_deg = mpi_rec.comm.bcast(angle_deg, root=0)
    start = mpi_rec.comm.bcast(start, root=0)
    end = mpi_rec.comm.bcast(end, root=0)
    return angle_deg, start, end
Esempio n. 4
0
def normalize_mpi(mpi_projs, flats, darks):
    """Flat/dark-field normalise the local projections.

    The flats and darks are each reduced to a single median frame (median
    rather than the Tomopy mean, to suppress outliers) before being passed
    to ``tomopy.normalize``.  The result is wrapped in a new MpiArray.
    """
    local_projs = mpi_projs.scatter(0)
    median_flat = np.median(flats, axis=0, keepdims=True)
    median_dark = np.median(darks, axis=0, keepdims=True)
    return MpiArray(tomopy.normalize(local_projs, median_flat, median_dark))
Esempio n. 5
0
 def test_fromlocalarray(self):
     """Build an MpiArray from a padded local slab and check its fields."""
     rows_per_rank = 7
     global_shape = (mpi_size * rows_per_rank, 10, 8)
     for pad in (0, 1, 2):
         global_arr = np.random.rand(*global_shape)
         # this rank's rows plus `pad` extra rows on each side; the lower
         # bound is clamped at 0 and slicing clips the upper bound
         first = max(0, mpi_rank * rows_per_rank - pad)
         last = (mpi_rank + 1) * rows_per_rank + pad
         wrapped = MpiArray(global_arr[first:last], padding=pad)
         self.check_fields(global_arr, wrapped)
def pos_refine_all(atlas,
                   W,
                   R,
                   u,
                   I,
                   bounds=[(0., 10.), (0., 10.)],
                   max_iters=1000,
                   sub_pixel=False):
    """Refine the position of every local frame with ``pos_refine_grid``.

    Iterates over ``R.arr`` and ``I.arr`` in lockstep, refines each
    position, prints the old and new position, and collects the results.

    Returns:
        ``(Rs, errs)``: two MpiArrays (axis=0) holding the refined
        positions and the per-frame errors.
    """
    refined_positions = []
    frame_errors = []
    for position, frame in zip(R.arr, I.arr):
        refined, error = pos_refine_grid(atlas, W, position, u, frame,
                                         bounds, max_iters, sub_pixel)
        refined_positions.append(refined.copy())
        frame_errors.append(error)
        print(position, refined)
        sys.stdout.flush()

    Rs = MpiArray(np.array(refined_positions), axis=0)
    errs = MpiArray(np.array(frame_errors), axis=0)
    return Rs, errs
Esempio n. 7
0
 def load_to_scattered(self, shape, axis=0, padding=0):
     """Return ``(arr, mpiarray)`` with the MpiArray already scattered.

     For multi-dimensional data the local array is replaced by an equal
     but non-contiguous copy, so tests exercise the non-contiguous path.
     """
     arr, mpiarray = self.load_fromglobalarray(shape)
     local_arr = mpiarray.scatter(axis, padding=padding)
     if len(local_arr.shape) <= 1:
         # one-dimensional data is kept as scattered
         return arr, mpiarray
     # make the axis-swapped view contiguous, then swap back: same values,
     # different memory layout
     swapped = np.require(np.swapaxes(local_arr, 0, 1), requirements="C")
     noncontiguous = np.swapaxes(swapped, 0, 1)
     assert_array_equal(local_arr, noncontiguous)  # sanity check
     rebuilt = MpiArray(noncontiguous, distribution=mpiarray.distribution)
     return arr, rebuilt
def get_input():
    """Collect and distribute the inputs of the 'pos_refine' step.

    Reads the command line / config through ``cmdline_config_cxi_reader``,
    loads the frames split by frame number (axis 0), applies the pixel
    shift offsets to the positions on rank 0, scatters the positions, marks
    masked pixels with -1 and builds the regularisation factor.

    Returns:
        ``(args, params, reg)``.
    """
    args, params = cmdline_config_cxi_reader.get_all(
        'pos_refine',
        'update the sample translations according to a least squares minimisation procedure',
        exclude=['frames'])
    params = params['pos_refine']

    # frames, split by frame no.
    limits = params['roi']
    frame_roi = (params['good_frames'],
                 slice(limits[0], limits[1]),
                 slice(limits[2], limits[3]))
    params['frames'] = MpiArray_from_h5(args.filename,
                                        params['frames'],
                                        axis=0,
                                        dtype=np.float64,
                                        roi=frame_roi)

    if rank != 0:
        # only rank 0 supplies the positions to the MpiArray below
        params['R_ss_fs'] = None
    elif params['pixel_shifts'] is not None:
        # offset positions
        params['R_ss_fs'] += fm.steps_offset(params['R_ss_fs'],
                                             params['pixel_shifts'])

    params['R_ss_fs'] = MpiArray(params['R_ss_fs'])
    params['R_ss_fs'].scatter(axis=0)

    # set masked pixels to negative 1, in every local frame and in the
    # whitefield
    masked = ~params['mask']
    for frame in params['frames'].arr:
        frame[masked] = -1
    params['whitefield'][masked] = -1

    if params['pixel_shifts'] is None:
        params['pixel_shifts'] = np.zeros((2, ) + params['whitefield'].shape,
                                          dtype=np.float64)

    # add a regularization factor
    reg = mk_reg(params['whitefield'].shape, params['reg'])

    return args, params, reg
Esempio n. 9
0
 def load_fromglobalarray(self, shape=(16, 16, 8)):
     """Create a random global array on rank 0 and wrap it in an MpiArray.

     Only rank 0 passes real data to ``MpiArray``; the other ranks pass
     ``None``.  Afterwards the array is broadcast so every rank can
     compare against it.

     Returns:
         ``(arr, mpiarray)``.
     """
     arr = None
     if mpi_rank == 0:
         if len(shape) <= 1:
             # single dimension, just load it
             arr = np.random.rand(*shape)
         else:
             # allocate with axes 0 and 1 exchanged and swap them back, so
             # the result has the requested shape but is non-contiguous
             # (unless an axis has length 1)
             swapped_shape = (shape[1], shape[0])
             if len(shape) > 2:
                 swapped_shape += shape[2:]
             arr = np.swapaxes(np.random.rand(*swapped_shape), 0, 1)
     mpiarray = MpiArray(arr)
     # share array to all MPI nodes
     arr = comm.bcast(arr)
     return arr, mpiarray
Esempio n. 10
0
def tomopy_recon_mpi(mpi_projs,
                     mpi_thetas,
                     center_offset,
                     algorithm="gridrec",
                     start_z=None,
                     end_z=None,
                     pad=None,
                     **kwargs):
    """Reconstruct distributed projections with ``tomopy.recon``.

    Sinograms are created from the projections, padded with ``pad_sinos``
    (X-axis padding to support flat samples), reconstructed in sinogram
    order, and the padding is removed again.  When both ``start_z`` and
    ``end_z`` are given the result is cropped on axis 1.

    Returns:
        MpiArray of the local reconstruction, using the sinogram
        distribution.
    """
    mpi_sinos = utils_mpi.create_sinos_mpi(mpi_projs)
    padded_sinos = pad_sinos(mpi_sinos.scatter(0), pad)
    all_thetas = mpi_thetas.allgather()  # global thetas
    # the centre is measured on the padded sinogram width
    rotation_center = padded_sinos.shape[2] // 2 + center_offset
    volume = tomopy.recon(padded_sinos, all_thetas, rotation_center, True,
                          algorithm, **kwargs)
    volume = remove_pad_rec(volume, pad)
    crop_requested = start_z is not None and end_z is not None
    if crop_requested:
        volume = volume[:, start_z:end_z, :]
    return MpiArray(volume, distribution=mpi_sinos.distribution)
Esempio n. 11
0
def extract_layers_mpi(mpi_rec, peak_template=None):
    """Detect layer boundaries in a reconstruction and extract the layers.

    Each slice along axis 1 is clipped at zero and Gaussian-smoothed; the
    summed difference between consecutive slices forms a 1-D profile that
    is smoothed, summed over all ranks, and searched for peaks on rank 0.
    The peaks are broadcast and one image per gap between consecutive peaks
    is extracted from the local (unpadded) data.

    NOTE: the IC should already be aligned to the reconstruction grid.

    Args:
        mpi_rec: Distributed reconstruction (scattered on axis 0 here).
        peak_template: Passed through to ``find_peaks``.

    Returns:
        MpiArray (axis=0) of shape
        ``(local_rows, number_of_peaks - 1, rec.shape[2])``.
    """
    # max difference between smoothed slices should be the layer boundary
    line = np.zeros((mpi_rec.shape[1] - 1, ), np.float32)
    rec = mpi_rec.scatter(0, 5)  # add padding

    # Rows of the padded local slab that belong to this rank.  The stop is
    # written as an explicit index: the original "-size_padding" form turns
    # into an empty slice when there is no trailing padding (0 == -0).
    offset_padding = mpi_rec.unpadded_offset - mpi_rec.offset
    size_padding = mpi_rec.size - mpi_rec.unpadded_size - offset_padding
    slc = np.s_[offset_padding:rec.shape[0] - size_padding]

    # smooth data between IC layers and take difference
    layer = None
    for i in range(rec.shape[1]):
        prev_layer = layer
        layer = np.copy(rec[:, i, :])
        layer[layer < 0] = 0
        layer = filters.gaussian_filter(layer, (5, 5))
        if prev_layer is not None:
            # remove padding from layers before taking diff
            line[i - 1] = np.sum(layer[slc] - prev_layer[slc])
    filters.gaussian_filter(line, 2, output=line)
    rec = mpi_rec.scatter(0)  # remove padding
    # sum line for all nodes
    total = np.zeros_like(line)
    mpi_rec.comm.Reduce([line, MPI.FLOAT], [total, MPI.FLOAT],
                        op=MPI.SUM,
                        root=0)
    if mpi_rec.mpi_rank == 0:
        # calculate peaks and then share to all nodes
        trim = 10  # ignore the first and last `trim` samples of the profile
        total[0:trim] = 0
        total[-trim:] = 0
        peaks = find_peaks(total, peak_template)
        logger.info("peaks from %d layers detected: %s" %
                    (len(peaks), str(peaks)))
        logger.info("peaks spacing: %s" % (str(peaks[1:] - peaks[:-1])))
    else:
        peaks = None
    peaks = mpi_rec.comm.bcast(peaks, root=0)

    layers = np.zeros((rec.shape[0], peaks.shape[0] - 1, rec.shape[2]),
                      dtype=rec.dtype)
    for p in reversed(range(peaks.shape[0] -
                            1)):  # reversed, since highest metal is first
        if peaks[p + 1] - peaks[p] >= 3:
            # average the slices strictly between the two boundaries
            layers[:, p] = np.mean(rec[:, peaks[p] + 1:peaks[p + 1] - 1],
                                   axis=1)
        else:
            # boundaries too close to average: take the middle slice
            layers[:, p] = rec[:, (peaks[p] + peaks[p + 1]) // 2]
    mpi_layers = MpiArray(layers, axis=0)
    return mpi_layers
Esempio n. 12
0
def global_median_of_images_mpi(mpi_images):
    """Return the median over axis 0 of a distributed image stack.

    The stack is scattered on axis 1, the median over axis 0 is taken
    locally, and the pieces are gathered on every rank.
    """
    local_part = mpi_images.scatter(1)
    local_median = np.median(local_part, axis=0)
    distributed_median = MpiArray(local_median, axis=0)
    return distributed_median.allgather()
def get_input():
    """Collect and distribute the inputs of the 'update_pixel_map' step.

    Reads the command line / config through ``cmdline_config_cxi_reader``,
    loads frames, mask and whitefield from ``args.filename`` split along
    the ss pixel axis, marks masked pixels with -1, applies the pixel-shift
    offsets to the positions, and scatters the pixel shifts on axis 1.

    Returns:
        ``(args, params)``.
    """
    args, params = cmdline_config_cxi_reader.get_all(
        'update_pixel_map',
        'update the pixel shifts according to a least squares minimisation procedure',
        exclude=['frames', 'whitefield', 'mask'])
    params = params['update_pixel_map']

    # split by ss pixels
    ####################

    # the frame roi also selects the good frames; roi[1:] is the 2-D part
    roi = params['roi']
    roi = (params['good_frames'], slice(roi[0], roi[1]), slice(roi[2], roi[3]))

    # frames are [frame, ss, fs], so the ss axis is axis 1
    params['frames'] = MpiArray_from_h5(args.filename,
                                        params['frames'],
                                        axis=1,
                                        roi=roi,
                                        dtype=np.float64)
    # the mask is 2-D, so the ss axis is axis 0.
    # np.bool was removed in NumPy 1.24; it was an alias of builtin bool.
    params['mask'] = MpiArray_from_h5(args.filename,
                                      params['mask'],
                                      axis=0,
                                      roi=roi[1:],
                                      dtype=bool)

    # the whitefield may be a single image or one image per frame
    with h5py.File(args.filename, 'r') as f:
        shape = f[params['whitefield']].shape
    if len(shape) == 2:
        params['whitefield'] = MpiArray_from_h5(args.filename,
                                                params['whitefield'],
                                                axis=0,
                                                roi=roi[1:],
                                                dtype=np.float64)
        params['whitefield'].arr[~params['mask'].arr] = -1
    else:
        params['whitefield'] = MpiArray_from_h5(args.filename,
                                                params['whitefield'],
                                                axis=1,
                                                roi=roi,
                                                dtype=np.float64)
        for i in range(params['whitefield'].arr.shape[0]):
            params['whitefield'].arr[i][~params['mask'].arr] = -1

    # set masked pixels to negative 1
    for i in range(params['frames'].arr.shape[0]):
        params['frames'].arr[i][~params['mask'].arr] = -1

    # offset positions
    if params['pixel_shifts'] is not None:
        params['R_ss_fs'] += fm.steps_offset(params['R_ss_fs'],
                                             params['pixel_shifts'])

    # special treatment for the pixel_shifts: rank 0 supplies the data
    # (zeros when none were given), the other ranks pass None
    if params['pixel_shifts'] is None and rank == 0:
        params['pixel_shifts'] = np.zeros(
            (2, ) + params['whitefield'].shape[-2:], dtype=np.float64)

    if rank != 0:
        params['pixel_shifts'] = None

    params['pixel_shifts'] = MpiArray(params['pixel_shifts'])
    params['pixel_shifts'].scatter(axis=1)

    return args, params
def grid_search_ss_split_sub_pix(atlas,
                                 I,
                                 W,
                                 R,
                                 u,
                                 bounds=[-1, 1],
                                 max_iters=100):
    """Refine the pixel shift of every local pixel to sub-pixel precision.

    For each pixel (i, j) of this rank's slab, ``scipy.optimize.minimize``
    searches for the offset (within ``bounds`` on both axes, starting from
    zero) that minimises the mean squared difference between the
    sub-pixel atlas prediction times the whitefield and the measured
    frames.  Offsets whose final error is positive are added to the output.

    Args:
        atlas: Atlas sampled by ``fm.sub_pixel_atlas_eval``.
        I: Frames container; ``I.arr`` is indexed as [frame, i, j].
        W: Whitefield container; ``W.arr`` is either 2-D or [frame, i, j].
        R: Per-frame positions with columns (ss, fs).
        u: Shift container; ``u.arr`` is indexed as [axis, i, j].
        bounds: ``[low, high]`` limits applied to both offset components.
        max_iters: ``maxiter`` option handed to the optimiser.

    Returns:
        MpiArray built from ``u.arr`` (axis=1) holding the updated shifts.

    ``rank``, ``comm`` and ``callback`` are not parameters; they are looked
    up in the enclosing module scope.
    """
    # imported once here instead of inside the innermost loop
    from scipy.optimize import minimize

    ss_offset = W.distribution.offsets[rank]

    def fun(x, i=0, j=0, w=0, uu=0):
        # np.float was removed in NumPy 1.24; it aliased builtin float,
        # i.e. a float64 array.
        forw = np.empty((R.shape[0], ), dtype=np.float64)
        fm.sub_pixel_atlas_eval(atlas, forw,
                                -R[:, 0] + x[0] + uu[0] + i + ss_offset,
                                -R[:, 1] + x[1] + uu[1] + j)

        m = (forw > 0) * (w > 0)
        n = np.sum(m)
        if n > 0:
            err = np.sum((forw[m] * w[m] - I.arr[m, i, j])**2) / n
        else:
            # sentinel: no usable sample; rejected by "res.fun > 0" below
            err = -1
        return err

    if len(W.arr.shape) == 2:
        # replicate a single whitefield so it can be indexed per frame
        w = np.array([W.arr for i in range(I.shape[0])])
    else:
        w = W.arr

    # define the pupil indices
    ss = np.arange(W.arr.shape[-2])
    fs = np.arange(W.arr.shape[-1])

    # Smallest "last local row index" over all ranks; used below to limit
    # the gather calls (presumably so every rank joins each collective).
    ss_min = comm.allreduce(np.max(ss), op=MPI.MIN)
    # NOTE(review): u_out is built directly from u.arr, so unless MpiArray
    # copies its input the in-place += below also modifies u -- confirm.
    u_out = MpiArray(u.arr, axis=1)
    for i in ss:
        for j in fs:
            x0 = np.array([0., 0.])
            err0 = fun(x0, i, j, w[:, i, j], u.arr[:, i, j])
            res = minimize(fun,
                           x0,
                           bounds=[bounds, bounds],
                           args=(i, j, w[:, i, j], u.arr[:, i, j]),
                           options={
                               'maxiter': max_iters,
                               'eps': 0.1,
                               'xatol': 0.1
                           })
            # progress output for the middle fs column on rank 0 only
            if rank == 0 and j == fs[fs.shape[0] // 2]:
                print(err0, res)
                sys.stdout.flush()
            if res.fun > 0:
                u_out.arr[:, i, j] += res.x

        # i % 1 is always 0, so the callback runs after every ss row
        # (not every 10 rows) while i <= ss_min.
        if i % 1 == 0 and i <= ss_min and callback is not None:
            ug = u_out.gather()
            callback(ug)

    return u_out
Esempio n. 15
0
def do_work(out_path, scan_id, center_offset, mpi_projs, flats, darks,
            mpi_thetas):
    """Run the full reconstruction pipeline for one scan.

    Steps, in order: flat-field normalisation, outlier removal and
    clipping, target-transmission correction, optional centre-finding
    output, a quick gridrec reconstruction used to measure the theta/phi
    layer angles, angle correction of thetas and projections, a second
    gridrec reconstruction, an LTT SART reconstruction, and finally layer
    extraction.  Intermediate stacks are written below ``out_path``.

    Args:
        out_path: Output directory; combined with names via ``/``.
        scan_id: Only used in the final log message.
        center_offset: Rotation-centre offset added to half the sinogram
            width.  When ``None`` the function writes centre-finding
            output and then terminates the process with ``sys.exit()``.
        mpi_projs: Distributed projections.
        flats: Flat-field frames handed to ``normalize_mpi``.
        darks: Dark-field frames handed to ``normalize_mpi``.
        mpi_thetas: Distributed projection angles.

    Returns:
        None.

    ``ncore``, ``mpi_rank``, ``mpi_size``, ``comm`` and ``logger`` are read
    from the module scope.
    """

    logger.info("Flat Field Correction")
    mpi_projs = normalize_mpi(mpi_projs, flats, darks)
    #utils_mpi.write_stack_mpi(out_path/"flat", mpi_projs)
    del flats
    del darks

    logger.info("Outlier Removal")
    # TODO: base parameters on clean simulation data! - might need fill
    projs = mpi_projs.scatter(0)
    # TODO: put back in, I think it is causing issues right now...
    tomopy.remove_outlier(projs, 0.1, 5, ncore=ncore, out=projs)
    #tomopy.remove_outlier_cuda(projs, 0.1, 5, ncore, out=projs)
    # clip in place to the open interval (0, 1)
    np.clip(projs, 1E-6, 1 - 1E-6, projs)

    # TODO: distortion correction factor?

    # TODO: ring removal?

    # # flat field change correction
    # remove_low_frequency_mpi(mpi_projs)
    # utils_mpi.write_stack_mpi(out_path/"low_frequency_removed", mpi_projs)

    # bulk Si intensity correction
    # removes constant absorption contribution from bulk Si, and mounting material
    # TODO: base parameters on clean simulation data! - will need fill
    # TODO: alternatively, refine result after good recon - with theta offset
    target_transmission = 0.80
    logger.info(f"Setting target transmission to {target_transmission}")
    set_target_transmission_mpi(mpi_projs, mpi_thetas, target_transmission)
    projs = mpi_projs.scatter(0)
    np.clip(projs, 1E-6, 1 - 1E-6, projs)
    utils_mpi.write_stack_mpi(out_path / "constant_transmission", mpi_projs)

    # center finding - manual for now?
    if center_offset is None:
        logger.info("Finding center")
        # algorithm = "SART"
        # pixel_size = 2 * 0.000016 #16nm bin 1
        # options = {"PixelWidth": pixel_size,
        #            "PixelHeight": pixel_size,
        #            "windowFOV": False,
        #            "archDir": out_path,
        #            "_mpi_rank": mpi_rank,
        #            }
        # alg_params = {"N_iter": 1,
        #               "N_subsets": 20,
        #               "nonnegativityConstraint": True,
        #               "useFBPasSeedImage": False,
        #               # "Preconditioner": "RAMP",
        #               # "beta": 2e-7,
        #               # "p": 1,
        #               # "delta": 1/20, # delta sets edge strength (difference between regions divide by ten)
        #               # "inverseVarianceExponent": 1.0, # set to 1 to include noise model
        #               # "other": 3, #convergence of low frequencies
        #               }
        # # load data into LTT, then find the center before recon
        # ltt_tomopy.initialize_recon(sinos, thetas, xcenter, True, algorithm, options, ncore=ncore)
        # center = align_tomo.find_center_ltt(lambda c: ltt_tomopy.preview(center=c, algorithm=algorithm, sinogram_order=True, close=False, options=options, alg_params=alg_params, ncore=ncore), xcenter, 0.1, ratio=0.8)
        # ltt_tomopy.recon_close()
        logger.info("Padding sinos for center finding")
        mpi_sinos = utils_mpi.create_sinos_mpi(mpi_projs, ncore)
        sinos = mpi_sinos.scatter(0)
        sinos = pad_sinos(sinos)
        mpi_sinos = MpiArray(sinos)
        gthetas = mpi_thetas.allgather()
        xcenter = sinos.shape[2] // 2
        # candidate centres: +/- 20 around the middle in steps of 0.5
        cen_range = (xcenter - 20, xcenter + 20, 0.5)
        # only the middle rank writes the centre-finding output
        if mpi_rank == mpi_size // 2:
            tomopy.write_center(sinos,
                                gthetas,
                                out_path / ("center"),
                                cen_range,
                                sinogram_order=True)
        del mpi_sinos, sinos
        import sys
        # wait for the writing rank, then stop: nothing below runs when
        # no centre offset was supplied
        comm.Barrier()
        sys.exit()
        # center_offset = mpi_projs.shape[2]//2-center
        # mpi_projs.comm.Barrier() #for printing
        # print(f"{mpi_projs.mpi_rank}: center={center} offset={center_offset}")
    #center = xcenter + center_offset

    # Shift correction?
    # use MPI and binning for speed
    # use LTT for all recon
    # define recon extent with extra X and less Z

    # Quick recon - LTT with SART or other? (can't use FBP)
    algorithm = "gridrec"
    options = {
        "filter_name": "parzen",
    }
    logger.info(f"Finding layer alignment within volume using {algorithm}")
    mpi_rec = tomopy_recon_mpi(mpi_projs,
                               mpi_thetas,
                               center_offset,
                               algorithm,
                               ncore=ncore,
                               **options)
    utils_mpi.write_stack_mpi(out_path / ("quick_" + algorithm), mpi_rec)
    theta_deg, start_z1, end_z1 = align_layers.find_angle_mpi(mpi_rec, 0)
    logger.info(f"Theta offset angle {theta_deg:0.2} deg")
    # find phi angle (axis 2)
    phi_deg, start_z2, end_z2 = align_layers.find_angle_mpi(mpi_rec, 2)
    logger.info(f"Phi offset angle {phi_deg:0.2} deg")
    start_z = min(start_z1, start_z2)
    end_z = max(end_z1, end_z2)
    # add buffer for start and end
    start_z = max(start_z - 20, 0)
    end_z = min(end_z + 20, mpi_rec.shape[1] - 1)
    # FIXME: override start and end -- the extent computed above is
    # discarded and the full axis-1 range is used instead
    start_z = 0
    end_z = mpi_rec.shape[1] - 1
    logger.info(f"Layer extent: {start_z} - {end_z}")

    # change theta with correction for next reconstruction
    # NOTE(review): the += is in place; presumably thetas is the array
    # backing mpi_thetas so later allgather() calls see the shift -- confirm
    thetas = mpi_thetas.scatter(0)
    thetas += np.deg2rad(theta_deg)
    # modify projections to remove phi angle
    # TODO: combine with stage shift code
    projs = mpi_projs.scatter(0)
    align_layers.apply_phi_correction(projs, thetas, phi_deg, projs)

    # Quick aligned recon
    algorithm = "gridrec"
    options = {
        "filter_name": "parzen",
    }
    logger.info("Quick Tomopy Recon")
    mpi_rec = tomopy_recon_mpi(mpi_projs,
                               mpi_thetas,
                               center_offset,
                               algorithm,
                               ncore=ncore,
                               **options)
    rec = mpi_rec.scatter(0)
    rec = rec[:, start_z:end_z, :]
    mpi_rec = MpiArray(rec)
    utils_mpi.write_stack_mpi(out_path / algorithm, mpi_rec)
    del mpi_rec, rec

    # Aligned recon
    # iterative recon with extra recon space in X and a restricted Z axis
    algorithm = "SART"  #"RDLS"#"DFM"#"ASD-POCS"#"SART"#"FBP"
    logger.info(f"Reconstructing aligned layers using {algorithm}")
    mpi_sinos = utils_mpi.create_sinos_mpi(mpi_projs, ncore)
    #utils_mpi.write_stack_mpi(out_path/"sinos", mpi_sinos)
    sinos = mpi_sinos.scatter(0)
    # add padding to recon - fixes cupping effect
    xrecpadding = sinos.shape[2] // 2
    pixel_size = 2 * 0.000016  # 16nm bin 1
    options = {
        "PixelWidth": pixel_size,
        "PixelHeight": pixel_size,
        "ryoffset": start_z,
        "ryelements": end_z - start_z,
        "windowFOV": False,
        "rxelements": sinos.shape[2] + 2 * xrecpadding,
        "rxoffset": -xrecpadding,
        "_mpi_rank": mpi_rank,
    }
    alg_params = {
        "N_iter": 50,
        "nonnegativityConstraint": False,
        "useFBPasSeedImage": False,
        #"Preconditioner": "RAMP",
        #"descentType": "CG",#"GD",
        #"beta": 2e-7,
        #"p": 1,
        #"delta": 20/20, # delta sets edge strength (difference between regions divide by ten)
        #"inverseVarianceExponent": 1.0, # set to 1 to include noise model
        #"other": 3, #convergence of low frequencies
    }
    #TODO: add support to add overlap in future with updates between iterations (see xray_trust6.py)
    gthetas = mpi_thetas.allgather()  #global thetas
    center = sinos.shape[2] // 2 + center_offset
    if gthetas[1] < gthetas[0]:
        # decreasing angle, LTT doesn't support, switch data around
        # TODO: load in reversed order?
        gthetas = gthetas[::-1]
        sinos[:] = sinos[:, ::-1, :]
    rec = ltt_tomopy.recon(sinos,
                           gthetas,
                           center,
                           True,
                           algorithm,
                           alg_params,
                           options,
                           ncore=ncore)
    # drop the extra X columns added through rxelements/rxoffset
    rec = rec[:, :, xrecpadding:xrecpadding + sinos.shape[2]]
    mpi_rec = MpiArray(rec, distribution=mpi_sinos.distribution)
    utils_mpi.write_stack_mpi(out_path / algorithm, mpi_rec)

    # Neural network processing?

    # Extract layers
    # Use template if available
    logger.info("Extracting Layers")
    mpi_layers = align_layers.extract_layers_mpi(mpi_rec)
    align_layers.write_layers_to_file_mpi(mpi_layers, "layers")

    logger.info(f"Finished {scan_id}")
filename = 'rs64_241proj_5X_9200eV_2_.h5'

os.chdir(working_dir)
start = time.time()

# Read HDF5 file.  Only the root rank touches the file; every other rank
# starts with placeholders.
# TODO: read directly into different mpi processes
logger.info("Reading data from H5 file %s" % filename)
proj, flat, dark, theta = (
    dxchange.read_aps_32id(filename, dtype=np.float32)
    if rank == 0 else (None, None, None, None))

# Distribute the projections along axis 0, then drop the full array to
# save memory.
proj = MpiArray.fromglobalarray(proj)
proj.scatter(0)
proj.arr = None

# Flats, darks and theta are needed on every MPI node; broadcast them in
# that order from the root.
flat, dark, theta = (comm.bcast(item, root=0) for item in (flat, dark, theta))

# Flat field correct the local projections in place and clip the result.
logger.info("Flat field correcting data")
proj.scatter(0)
tomopy.normalize(proj.local_arr, flat, dark, ncore=1, out=proj.local_arr)
np.clip(proj.local_arr, 1e-6, 1.0, out=proj.local_arr)
del flat, dark
    # refine the positions
    m = params['max_step']
    params['R_ss_fs'], errs = pos_refine_all(atlas,
                                             params['whitefield'],
                                             params['R_ss_fs'],
                                             params['pixel_shifts'],
                                             params['frames'],
                                             bounds=[(-m, m), (-m, m)],
                                             max_iters=params['max_iters'],
                                             sub_pixel=params['sub_pixel'])

    errs2 = errs.allgather()
    em = errs2.mean()
    weights = np.abs(errs.arr - em)
    weights = 1. - weights / weights.max() + 0.2
    weights = MpiArray(weights, axis=0)

    #atlas  = build_atlas_distortions_MpiArray(params['frames'],
    #                                          params['whitefield'],
    #                                          params['R_ss_fs'],
    #                                          params['pixel_shifts'],
    #                                          reg=reg, weights = weights.arr)

    params['R_ss_fs'] = params['R_ss_fs'].gather()
    errs = errs.gather()
    weights = weights.gather()

    # real-time output
    if rank == 0:
        out = {'R_ss_fs': params['R_ss_fs'], 'errs': errs, 'weights': weights}
        cmdline_config_cxi_reader.write_all(params, args.filename, out)