Exemplo n.º 1
0
def align_3D_coarse_axes(img_ref, img1, circle_mask_ratio=0.6, axes=0, shift_flag=1):
    '''
    Estimate the translation between two 3D volumes from their projections
    along one axis.

    Each volume is copied, optionally passed through pyxas.circ_mask, cropped
    along its first dimension to a central band, and summed along `axes`.
    The two 2D projections are then registered with StackReg (translation only).

    Inputs:
    -----------
    img_ref: 3D array
        reference volume

    img1: 3D array
        volume to be aligned to img_ref

    circle_mask_ratio: float
        if < 1, pyxas.circ_mask(..., axis=0, ratio=circle_mask_ratio, val=0)
        is applied to the copies of both volumes.
        It also sets the centered fraction of the first dimension that is
        kept before projecting (applied even when the ratio is >= 1).

    axes: int
        axis along which the volumes are summed: 0, 1 or 2.
        For any other value the returned shift is all zeros.

    shift_flag: int / bool
        if true, img1 is also shifted with pyxas.shift(..., order=0)

    Output:
    ----------------
    (aligned volume, shift_matrix) if shift_flag is true, otherwise only
    shift_matrix. shift_matrix is a length-3 array ordered like the volume axes.
    '''

    def _prepare(volume):
        # Work on a copy so the caller's array is never handed to circ_mask.
        work = volume.copy()
        if circle_mask_ratio < 1:
            return pyxas.circ_mask(work, axis=0, ratio=circle_mask_ratio, val=0)
        return work

    ref_vol = _prepare(img_ref)
    n_first = ref_vol.shape[0]
    # Central band of the first dimension; the same band is used for both volumes.
    lo = int(n_first*(0.5-circle_mask_ratio/2))
    hi = int(n_first*(0.5+circle_mask_ratio/2))
    prj_ref = np.sum(ref_vol[lo:hi], axis=axes)

    mov_vol = _prepare(img1)
    prj_mov = np.sum(mov_vol[lo:hi], axis=axes)

    registrar = StackReg(StackReg.TRANSLATION)
    tmat = registrar.register(prj_ref, prj_mov)
    # The translation terms of the matrix are negated to obtain the shifts.
    row_shift = -tmat[1, 2]
    col_shift = -tmat[0, 2]

    # Place the two in-plane shifts on the axes that were not summed over.
    if axes == 0:
        shift_matrix = np.array([0, row_shift, col_shift])
    elif axes == 1:
        shift_matrix = np.array([row_shift, 0, col_shift])
    elif axes == 2:
        shift_matrix = np.array([row_shift, col_shift, 0])
    else:
        shift_matrix = np.array([0, 0, 0])

    if not shift_flag:
        return shift_matrix
    img_ali = pyxas.shift(img1, shift_matrix, order=0)
    return img_ali, shift_matrix
Exemplo n.º 2
0
def align_3D_tomo_file_mpi_specific(file_save_path, files_recon=[], files_ref='', binning=1, circle_mask_ratio=0.8, file_type='.h5', align_coarse=1, align_method=1, hdf_attr='img', num_cpu=4):
    '''
    Align a list of reconstruction files to a reference file using a
    multiprocessing pool.

    The reference volume is loaded with get_tomo_image, optionally binned,
    masked and (for align_method == 1) centered, saved to disk, and then
    passed to align_3D_tomo_file_mpi_sub for every entry of files_recon.

    Inputs:
    -----------
    file_save_path: str
        directory in which the processed reference and the aligned files
        are saved (converted to an absolute path)
    files_recon: list of str
        files to be aligned
    files_ref: str
        reference file
    binning: int
        if > 1, the reference is binned with pyxas.bin_image
    circle_mask_ratio: float
        if < 1, pyxas.circ_mask(..., axis=0, ratio=circle_mask_ratio) is applied
    file_type, hdf_attr:
        forwarded to get_tomo_image
    align_coarse:
        forwarded to the worker
    align_method:
        1:  old method
        2:  3D cross-correlation
    num_cpu: int
        upper bound on worker processes; also capped at ~80% of cpu_count()
        and never less than 1

    Output:
    ----------------
    None, files are written to file_save_path
    '''
    from multiprocessing import Pool, cpu_count
    from functools import partial
    # Clamp to at least one worker: Pool(0) raises ValueError.
    num_cpu = max(1, min(round(cpu_count() * 0.8), num_cpu))
    print(f'align_3D_tomo using {num_cpu:2d} CPUs')
    # save ref image
    file_path = os.path.abspath(file_save_path)
    img_ref, scan_id, X_eng = get_tomo_image(files_ref, file_type, hdf_attr)
    if binning > 1:
        img_ref = pyxas.bin_image(img_ref, binning)
    if circle_mask_ratio < 1:
        img_ref = pyxas.circ_mask(img_ref, axis=0, ratio=circle_mask_ratio)
    if align_method == 1:
        img_ref = pyxas.move_3D_to_center(img_ref, circle_mask_ratio=circle_mask_ratio)

    # Use the absolute path here as well, so the reference lands in the same
    # directory the workers write to (matches align_3D_tomo_file_specific).
    fn_save = f'{file_path}/ali_recon_{scan_id}_bin_{binning}.h5'
    pyxas.save_hdf_file(fn_save, 'img', img_ref.astype(np.float32), 'scan_id', scan_id, 'X_eng', X_eng)
    # start align
    worker = partial(align_3D_tomo_file_mpi_sub,
                     img_ref=img_ref,
                     file_path=file_path,
                     binning=binning,
                     circle_mask_ratio=circle_mask_ratio,
                     file_type=file_type,
                     align_coarse=align_coarse,
                     align_method=align_method,
                     hdf_attr=hdf_attr)
    pool = Pool(num_cpu)
    try:
        pool.map(worker, files_recon)
    finally:
        # Always release the worker processes, even if a worker raised.
        pool.close()
        pool.join()
Exemplo n.º 3
0
def align_3D_tomo_file_specific(file_save_path='.', files_recon=[], files_ref='', binning=1, circle_mask_ratio=0.9, file_type='.h5', align_coarse=1, align_method=1, hdf_attr='img'):    
    '''
    Serial version of the file-based 3D alignment: prepare and save the
    reference volume, then align every file of files_recon one after another
    with align_3D_tomo_file_mpi_sub.

    The cumulative elapsed time is printed after each file.

    align_method: 
        1:  old method (the reference is additionally centered with
            pyxas.move_3D_to_center)
        2:  3D cross-correlation
    '''
    import time
    file_path = os.path.abspath(file_save_path)

    # Load and pre-process the reference volume.
    img_ref, scan_id, X_eng = get_tomo_image(files_ref, file_type, hdf_attr)
    if binning > 1:
        img_ref = pyxas.bin_image(img_ref, binning)
    if circle_mask_ratio < 1:
        img_ref = pyxas.circ_mask(img_ref, axis=0, ratio=circle_mask_ratio)
    if align_method == 1:
        img_ref = pyxas.move_3D_to_center(img_ref, circle_mask_ratio=circle_mask_ratio)

    # Store the processed reference next to the aligned outputs.
    ref_out = f'{file_path}/ali_recon_{scan_id}_bin_{binning}.h5'
    pyxas.save_hdf_file(ref_out, 'img', img_ref.astype(np.float32), 'scan_id', scan_id, 'X_eng', X_eng)

    time_start = time.time()
    for fn in files_recon:
        align_3D_tomo_file_mpi_sub(
            fn,
            img_ref,
            file_path=file_path,
            binning=binning,
            circle_mask_ratio=circle_mask_ratio,
            file_type=file_type,
            align_coarse=align_coarse,
            align_method=align_method,
            hdf_attr=hdf_attr,
        )
        print(f'time elasped: {time.time() - time_start:05.1f}\n')
Exemplo n.º 4
0
def align_3D_tomo_file_mpi_sub(files_recon, img_ref, file_path='.', binning=1, circle_mask_ratio=0.9, file_type='.h5', align_coarse=1, align_method=1, hdf_attr='img'):
    '''
    Align a single reconstruction file to an already prepared reference
    volume and save the result in file_path.

    Inputs:
    -----------
    files_recon: str
        path of the ONE file to align (the name is plural for historical reasons)
    img_ref: 3D array
        reference volume
    file_path: str
        output directory
    binning: int
        if > 1, the loaded volume is binned with pyxas.bin_image and
        "_bin_{binning}" is appended to the .h5 output name
    circle_mask_ratio: float
        if < 1, pyxas.circ_mask(..., axis=0, ratio=circle_mask_ratio) is applied
    file_type, hdf_attr:
        forwarded to get_tomo_image
    align_coarse:
        forwarded to align_3D_tomo_image
    align_method: 
        1:  old method
        2:  3D cross-correlation

    Output:
    ----------------
    None. Writes "ali_{input name}" via io.imsave when X_eng == 0,
    otherwise "ali_recon_{scan_id}{bin suffix}.h5" via pyxas.save_hdf_file.
    '''
    bin_info = ''
    fn = files_recon
    fn_short = fn.split('/')[-1]
    print(f'aligning {fn_short} ...')    
    img1, scan_id, X_eng = get_tomo_image(files_recon, file_type, hdf_attr)
    if circle_mask_ratio < 1:
        img1 = pyxas.circ_mask(img1, axis=0, ratio=circle_mask_ratio)
    if binning > 1:
        img1 = pyxas.bin_image(img1, binning)
        # Was `bin_info == ...` (a no-op comparison), so the suffix was never
        # applied and binned outputs were saved without "_bin_N" in the name.
        bin_info = f'_bin_{binning}'
    img_ali = align_3D_tomo_image(img1, img_ref, circle_mask_ratio, align_method, align_coarse)
    if X_eng == 0: # read tiff file
        fn_save = f'{file_path}/ali_{fn_short}'  
        print(f'saving aligned file: {fn_save.split("/")[-1]}\n')
        io.imsave(fn_save, img_ali.astype(np.float32))
    else:
        fn_save = f'{file_path}/ali_recon_{scan_id}{bin_info}.h5'  
        print(f'saving aligned file: {fn_save.split("/")[-1]}\n')
        pyxas.save_hdf_file(fn_save, 'img', img_ali.astype(np.float32), 'scan_id', scan_id, 'X_eng', X_eng)   
Exemplo n.º 5
0
def rgb(img_r, img_g=[], img_b=[], norm_flag=1, filter_size=3, circle_mask_ratio=0.8):
    '''
    Compose an RGB image from up to three 2D images and display it in a new
    matplotlib figure.

    Inputs:
    -----------
    img_r: 2D array
        red channel (required; defines the image shape)
    img_g, img_b: 2D array, optional
        green / blue channels; an empty value is replaced by zeros
    norm_flag: int / bool
        if true, each channel is divided by its own maximum.
        Channels whose maximum is 0 are left untouched (avoids 0/0 -> NaN,
        which is the default situation for omitted channels).
    filter_size: int
        kernel size passed to scipy.signal.medfilt when >= 2;
        for smaller values the channels are used unfiltered
    circle_mask_ratio: float
        ratio forwarded to pyxas.circ_mask(..., axis=2)

    Output:
    ----------------
    None, the image is shown with plt.imshow
    '''
    from scipy.signal import medfilt
    assert len(img_r.shape) == 2, '2D image only'
    s = img_r.shape
    if len(img_g) == 0:
        img_g = np.zeros(s)
    if len(img_b) == 0:
        img_b = np.zeros(s)

    channels = [img_r, img_g, img_b]
    if filter_size >= 2:
        kernel = int(filter_size)
        channels = [medfilt(ch, kernel) for ch in channels]
    # else: use the raw channels (previously r/g/b were undefined -> NameError)

    img_rgb = np.zeros([s[0], s[1], 3])
    for i, ch in enumerate(channels):
        img_rgb[:, :, i] = ch
    img_rgb = pyxas.circ_mask(img_rgb, axis=2, ratio=circle_mask_ratio)
    if norm_flag:
        for i in range(3):
            t = np.max(img_rgb[:, :, i])
            if t != 0:
                img_rgb[:, :, i] /= t
    plt.figure()
    plt.imshow(img_rgb) 
Exemplo n.º 6
0
def move_3D_to_center(img, circle_mask_ratio=1):
    '''
    Shift a volume so that its center of mass lands on shape/2.

    The center of mass is computed on the volume after pyxas.circ_mask
    (axis=0, val=0) when circle_mask_ratio < 1, otherwise on the volume
    as given. The shift itself is always applied to the input `img`
    using pyxas.shift with order=0.

    Returns the shifted volume.
    '''
    from scipy.ndimage import center_of_mass
    target = np.array(img.shape)/2
    if circle_mask_ratio < 1:
        weighted = pyxas.circ_mask(img, axis=0, ratio=circle_mask_ratio, val=0)
    else:
        weighted = img
    centroid = np.array(center_of_mass(weighted))
    # target - current position, one entry per axis
    offsets = list(target - centroid)
    return pyxas.shift(img, offsets, order=0)
Exemplo n.º 7
0
def align_3D_coarse(img_ref, img1, circle_mask_ratio=1, method='other'):
    '''
    Coarse 3D alignment of img1 to img_ref.

    Inputs:
    -----------
    img_ref: 3D array
        reference volume
    img1: 3D array
        volume to be aligned
    circle_mask_ratio: float
        if < 1, pyxas.circ_mask (axis=0, val=0) is applied before the
        centers of mass are computed; also forwarded to
        pyxas.align_3D_coarse_axes
    method: str
        'center_mass': shift img1 so that its center of mass matches that
                       of img_ref
        anything else: register the projections along axes 0, 1 and 2 with
                       pyxas.align_3D_coarse_axes and combine the results

    Output:
    ----------------
    aligned volume (pyxas.shift of img1, order=0), shift_matrix
    '''
    # Local import, as in the sibling functions, so this block does not
    # depend on a module-level import of center_of_mass.
    from scipy.ndimage import center_of_mass

    if method == 'center_mass':
        img0_crop = img_ref
        img1_crop = img1
        if circle_mask_ratio < 1:
            img0_crop = pyxas.circ_mask(img0_crop, axis=0, ratio=circle_mask_ratio, val=0)
            img1_crop = pyxas.circ_mask(img1_crop, axis=0, ratio=circle_mask_ratio, val=0)
        cm0 = np.array(center_of_mass(img0_crop))
        cm1 = np.array(center_of_mass(img1_crop))
        # Shift = target - current (same convention as move_3D_to_center).
        # The previous `cm1 - cm0` had the opposite sign and moved img1
        # away from the reference instead of onto it.
        shift_matrix = cm0 - cm1
    else:
        per_axis = [
            pyxas.align_3D_coarse_axes(img_ref, img1, circle_mask_ratio=circle_mask_ratio, axes=ax, shift_flag=0)
            for ax in (0, 1, 2)
        ]
        # Each axis' shift is measured in exactly two of the three
        # projections (it is 0 in the one summed along that axis),
        # so sum / 2 is the average of the two estimates.
        shift_matrix = (per_axis[0] + per_axis[1] + per_axis[2]) / 2.0
    print(f'shifts: {shift_matrix}')
    img_ali = pyxas.shift(img1, shift_matrix, order=0)
    return img_ali, shift_matrix
Exemplo n.º 8
0
def align_3D_fine(img_ref, img1, circle_mask_ratio=1, sli_select=0, row_select=0, test_range=[-30, 30], sli_shift_guess=0, row_shift_guess=0, col_shift_guess=0, cen_mass_flag=0, ali_direction=[1,1,1]):

    '''
    Fine 3D alignment of img1 to img_ref using two orthogonal 2D cuts.

    First the cut at [:, row_select] is aligned with pyxas.align_img to get
    the shift along the first axis (sli); then the cut at [sli_select] is
    aligned to get the row and column shifts.

    Inputs:
    -----------
    img_ref: 3D array
        reference volume
    img1: 3D array
        volume to be aligned
    circle_mask_ratio: float
        if < 1, pyxas.circ_mask (axis=0, val=0) is applied to copies of
        both volumes
    sli_select, row_select: int
        indices of the cuts; 0 or an out-of-range value selects the middle
    test_range: list
        no longer used; kept so existing callers do not break. It only fed
        a correlation scan whose result was never used.
    sli_shift_guess, row_shift_guess, col_shift_guess:
        initial shift applied to the working copy of img1 before aligning
        (not included in the returned shift_matrix)
    cen_mass_flag: int / bool
        if true, sli_select / row_select are taken from the center of mass
        of projections of the reference
    ali_direction: list
        ali_direction = [1,1,1] -> shift [sli, row, col] if it is "1"

    Output:
    ----------------
    aligned (masked) volume, shift_matrix = [sli_shft, rshft, cshft]
    '''
    import time
    from scipy.ndimage import center_of_mass
    time_s = time.time()

    # Work on copies so the inputs are never handed to circ_mask directly.
    img_ref_crop = img_ref.copy()
    img_raw_crop = img1.copy()
    if circle_mask_ratio < 1:
        img_ref_crop = pyxas.circ_mask(img_ref_crop, axis=0, ratio=circle_mask_ratio, val=0)
        img_raw_crop = pyxas.circ_mask(img_raw_crop, axis=0, ratio=circle_mask_ratio, val=0)
    if sli_shift_guess != 0 or row_shift_guess != 0 or col_shift_guess != 0:
        img_raw_crop= shift(img_raw_crop, [sli_shift_guess, row_shift_guess, col_shift_guess], order=0)

    if sli_select == 0 or sli_select >= img_ref_crop.shape[0]:
        sli_select = int(img_ref_crop.shape[0]/2.0)

    if row_select == 0 or row_select >= img_ref_crop.shape[1]:
        row_select = int(img_ref_crop.shape[1]/2.0)

    if cen_mass_flag:
        prj_ref = np.sum(img_ref_crop, axis=1)
        sli_select = int(center_of_mass(prj_ref)[0])
        prj_ref = np.sum(img_ref_crop, axis=0)
        row_select = int(center_of_mass(prj_ref)[0])
    print(f'aligning using sli = {sli_select}, row = {row_select}')

    # Default when the height alignment is skipped; previously sli_shft was
    # undefined in that case and building shift_matrix raised NameError.
    sli_shft = 0

    # align height first (sli)
    if ali_direction[0] == 1:
        print('aligning height ...')
        t1 = np.squeeze(img_ref_crop[:, row_select])
        t1 = t1/np.mean(t1)
        # NOTE: an FFT correlation scan over test_range used to run here, but
        # its result was never used (the assignment was commented out), so
        # it has been removed; it could also index out of range.
        t2 = np.squeeze(img_raw_crop[:, row_select])
        sli_shft, cshft = pyxas.align_img(t1, t2, align_flag=0)
        img_raw_crop = shift(img_raw_crop, [sli_shft, 0, 0], order=1)

    # align row and col
    print('aligning row and col ...')
    t1 = img_ref_crop[sli_select]
    t1 = t1/np.mean(t1)
    t2 = img_raw_crop[sli_select]
    rshft, cshft = pyxas.align_img(t1, t2, align_flag=0)

    if ali_direction[1] == 0:
        rshft = 0
    if ali_direction[2] == 0:
        cshft = 0
    img_ali= shift(img_raw_crop, [0, rshft, cshft], order=1)

    shift_matrix = [sli_shft, rshft, cshft]
    print(f'sli_shift: {sli_shft: 04.1f},   rshft: {rshft: 04.1f},   cshft: {cshft: 04.1f}')
    print(f'time elapsed: {time.time() - time_s:4.2f} sec')
    return img_ali, shift_matrix
Exemplo n.º 9
0
def fit_xanes2D_align_tomo_proj(file_path='.', files_scan=[], binning=2, ref_index=-1, ref_rot_cen=-1, block_list=[], sli=[], ratio=0.8, file_prefix='fly', file_type='.h5'):
    '''
    Align the tomo-scan projections to an assigned reference scan file, and
    generate a 3D reconstruction for every scan.

    Inputs:
    -----------
    file_path: str
        Directory containing all "fly_scans"
    files_scan: list of str
        scan files to process; if empty, they are collected with
        pyxas.retrieve_file_type(file_path, file_prefix, file_type)
    binning: int
        binning of reconstruction
    ref_index: int
        index of the "fly_scan" which is assigned as reference projections.
        This fly_scan should have good reconstruction quality and a full
        list of rotation angles.
        if -1: use the last scan (sorted by alphabetic file name)
    ref_rot_cen: float
        rotation center for the referenced "fly_scan"
        if -1: find the rotation center with find_rot
        (cross-correlation at angle-0 and angle-180)
    block_list: list
        indexes of bad projections
        e.g., list(np.arange(380,550))
    sli: list
        forwarded to pyxas.align_proj_sub
    ratio: float: (0 < ratio < 1)
        fraction of the projection image used to align projections; also
        used as the circular mask ratio of the reconstruction
        e.g., 0.6
    file_prefix: str
        prefix of the "fly_scan"
        e.g., 'fly'
    file_type: str
        e.g. '.h5'

    Output:
    ----------------
    None, will save aligned 3D reconstruction in folder of "{file_path}/ali_recon"
    '''
    # Local import, as in the sibling functions, so time.time() below does
    # not depend on a module-level import.
    import time

    file_path = os.path.abspath(file_path)
    binning = int(binning)
    if len(files_scan) == 0:
        files_scan = pyxas.retrieve_file_type(file_path, file_prefix=file_prefix, file_type=file_type)
    num_files = len(files_scan)

    fn_ref = files_scan[ref_index]
    # Only the angles of the reference scan are needed here.
    _, _, angle_ref = pyxas.retrieve_norm_tomo_image(fn_ref, index=ref_index, binning=binning)
    if ref_rot_cen == -1:
        rot_cen = find_rot(fn_ref) / binning
    else:
        rot_cen = ref_rot_cen / binning

    new_dir = f'{file_path}/ali_recon'
    # exist_ok avoids the check-then-create race of the previous version.
    os.makedirs(new_dir, exist_ok=True)

    for i, fn in enumerate(files_scan):
        # Context manager guarantees the file is closed even if a read fails.
        with h5py.File(fn, 'r') as f1:
            scan_id = np.array(f1['scan_id'])
            angle1 = np.array(f1['angle'])
        theta1 = angle1 / 180 * np.pi
        s = time.time()
        img1_ali, eng1, rshft, cshft = pyxas.align_proj_sub(fn_ref, angle_ref, fn, angle1, binning, ratio=ratio, sli=sli, ali_method='stackreg')
        print(f'#{i}/{num_files}, time elapsed: {time.time()-s}\n')

        img1_ali = pyxas.norm_txm(img1_ali)
        rec = pyxas.recon_sub(img1_ali, theta1, rot_cen, block_list)
        rec = pyxas.circ_mask(rec, axis=0, ratio=ratio, val=0)
        print('saving files...\n')

        fn_save = f'{new_dir}/ali_recon_{scan_id}.h5'
        with h5py.File(fn_save, 'w') as hf:
            hf.create_dataset('img', data = rec.astype(np.float32))
            hf.create_dataset('scan_id', data = scan_id)
            hf.create_dataset('XEng', data = eng1)
            hf.create_dataset('angle', data = angle1)
        print(f'{fn_save} saved\n')