def align_3D_coarse_axes(img_ref, img1, circle_mask_ratio=0.6, axes=0, shift_flag=1):
    '''
    Coarsely align a 3D volume to a reference volume using 2D projections.

    Both volumes are summed along `axes` and the two 2D projections are
    registered with a translation-only StackReg model.

    Inputs:
    -----------
    img_ref: 3D array
        reference volume
    img1: 3D array
        volume to be aligned
    circle_mask_ratio: float
        if < 1, a circular mask (pyxas.circ_mask, axis=0, val=0) with this
        ratio is applied to copies of both volumes. In all cases, only the
        central fraction `circle_mask_ratio` of slices along axis 0 is
        projected. That range is clamped to the valid slice range.
    axes: int
        axis (0, 1 or 2) along which the volumes are summed
    shift_flag: int/bool
        if true, img1 is also shifted (pyxas.shift, order=0)

    Output:
    ----------------
    (aligned volume, shift_matrix) if shift_flag else shift_matrix
    shift_matrix is [sli, row, col]. The component along `axes` is 0, and an
    unsupported `axes` value gives [0, 0, 0].
    '''
    def _masked_copy(img):
        # work on a copy so the caller's array is never modified
        img_tmp = img.copy()
        if circle_mask_ratio < 1:
            return pyxas.circ_mask(img_tmp, axis=0, ratio=circle_mask_ratio, val=0)
        return img_tmp

    img_ref_crop = _masked_copy(img_ref)
    img_raw_crop = _masked_copy(img1)

    # Central band of slices along axis 0. Clamp to [0, n_sli]: with
    # circle_mask_ratio > 1 the start would otherwise be negative and the
    # slice would silently keep only the last few slices.
    n_sli = img_ref_crop.shape[0]
    sli_start = max(0, int(n_sli * (0.5 - circle_mask_ratio / 2)))
    sli_end = min(n_sli, int(n_sli * (0.5 + circle_mask_ratio / 2)))

    prj0 = np.sum(img_ref_crop[sli_start:sli_end], axis=axes)
    prj1 = np.sum(img_raw_crop[sli_start:sli_end], axis=axes)

    sr = StackReg(StackReg.TRANSLATION)
    tmat = sr.register(prj0, prj1)
    # tmat[1, 2] is used as the row offset and tmat[0, 2] as the column
    # offset; both are negated to obtain the shift applied to img1.
    r = -tmat[1, 2]
    c = -tmat[0, 2]

    if axes == 0:
        shift_matrix = np.array([0, r, c])
    elif axes == 1:
        shift_matrix = np.array([r, 0, c])
    elif axes == 2:
        shift_matrix = np.array([r, c, 0])
    else:
        shift_matrix = np.array([0, 0, 0])

    if shift_flag:
        img_ali = pyxas.shift(img1, shift_matrix, order=0)
        return img_ali, shift_matrix
    return shift_matrix
def align_3D_tomo_file_mpi_specific(file_save_path, files_recon=None, files_ref='', binning=1, circle_mask_ratio=0.8, file_type='.h5', align_coarse=1, align_method=1, hdf_attr='img', num_cpu=4):
    '''
    Align several 3D reconstruction files to one reference file in parallel.

    The reference volume is loaded with get_tomo_image, then optionally binned
    and circle-masked. For align_method == 1 it is also centered with
    pyxas.move_3D_to_center. It is saved as
    "{file_save_path}/ali_recon_{scan_id}_bin_{binning}.h5", and every file in
    `files_recon` is aligned to it by align_3D_tomo_file_mpi_sub in a
    multiprocessing.Pool.

    Inputs:
    -----------
    file_save_path: str
        output directory
    files_recon: list of str
        files to align (default: none)
    files_ref: str
        reference file
    binning: int
        binning applied when > 1
    circle_mask_ratio: float
        circular-mask ratio, applied when < 1
    file_type, hdf_attr:
        forwarded to the file reader
    align_coarse: int
        forwarded to the per-file worker
    align_method: int
        1: old method
        2: 3D cross-correlation
    num_cpu: int
        maximum number of worker processes. It is also capped at
        round(0.8 * cpu_count()) and is at least 1.
    '''
    from multiprocessing import Pool, cpu_count
    from functools import partial

    if files_recon is None:
        files_recon = []
    # at least one worker: Pool(0) raises ValueError
    num_cpu = max(1, min(round(cpu_count() * 0.8), num_cpu))
    print(f'align_3D_tomo using {num_cpu:2d} CPUs')

    # prepare and save the reference image
    file_path = os.path.abspath(file_save_path)
    img_ref, scan_id, X_eng = get_tomo_image(files_ref, file_type, hdf_attr)
    if binning > 1:
        img_ref = pyxas.bin_image(img_ref, binning)
    if circle_mask_ratio < 1:
        img_ref = pyxas.circ_mask(img_ref, axis=0, ratio=circle_mask_ratio)
    if align_method == 1:
        img_ref = pyxas.move_3D_to_center(img_ref, circle_mask_ratio=circle_mask_ratio)
    fn_save = f'{file_path}/ali_recon_{scan_id}_bin_{binning}.h5'
    pyxas.save_hdf_file(fn_save, 'img', img_ref.astype(np.float32), 'scan_id', scan_id, 'X_eng', X_eng)

    # start align
    worker = partial(align_3D_tomo_file_mpi_sub, img_ref=img_ref, file_path=file_path, binning=binning,
                     circle_mask_ratio=circle_mask_ratio, file_type=file_type, align_coarse=align_coarse,
                     align_method=align_method, hdf_attr=hdf_attr)
    pool = Pool(num_cpu)
    try:
        pool.map(worker, files_recon)
    finally:
        # always release the worker processes, even if a worker raised
        pool.close()
        pool.join()
def align_3D_tomo_file_specific(file_save_path='.', files_recon=None, files_ref='', binning=1, circle_mask_ratio=0.9, file_type='.h5', align_coarse=1, align_method=1, hdf_attr='img'):
    '''
    Align several 3D reconstruction files to one reference file, one at a time.

    This is the serial counterpart of align_3D_tomo_file_mpi_specific. The
    reference is loaded, then optionally binned, circle-masked and (for
    align_method == 1) centered. It is saved as
    "{abs(file_save_path)}/ali_recon_{scan_id}_bin_{binning}.h5". Each file
    in `files_recon` is then passed to align_3D_tomo_file_mpi_sub, and the
    cumulative elapsed time is printed after each one.

    align_method:
        1: old method
        2: 3D cross-correlation
    '''
    import time

    if files_recon is None:
        files_recon = []
    file_path = os.path.abspath(file_save_path)
    img_ref, scan_id, X_eng = get_tomo_image(files_ref, file_type, hdf_attr)
    if binning > 1:
        img_ref = pyxas.bin_image(img_ref, binning)
    if circle_mask_ratio < 1:
        img_ref = pyxas.circ_mask(img_ref, axis=0, ratio=circle_mask_ratio)
    if align_method == 1:
        img_ref = pyxas.move_3D_to_center(img_ref, circle_mask_ratio=circle_mask_ratio)
    fn_save = f'{file_path}/ali_recon_{scan_id}_bin_{binning}.h5'
    pyxas.save_hdf_file(fn_save, 'img', img_ref.astype(np.float32), 'scan_id', scan_id, 'X_eng', X_eng)

    time_start = time.time()
    for fn in files_recon:
        align_3D_tomo_file_mpi_sub(fn, img_ref, file_path, binning, circle_mask_ratio, file_type,
                                   align_coarse, align_method, hdf_attr)
        # cumulative time since the first file started
        print(f'time elasped: {time.time() - time_start:05.1f}\n')
def align_3D_tomo_file_mpi_sub(files_recon, img_ref, file_path='.', binning=1, circle_mask_ratio=0.9, file_type='.h5', align_coarse=1, align_method=1, hdf_attr='img'):
    '''
    Align one 3D reconstruction file to `img_ref` and save the result.

    This is the worker used by align_3D_tomo_file_(mpi_)specific.

    Inputs:
    -----------
    files_recon: str
        a single file name (despite the plural parameter name)
    img_ref: 3D array
        reference volume, already prepared by the caller
    file_path: str
        output directory
    binning, circle_mask_ratio:
        the image is circle-masked when ratio < 1 and binned when binning > 1
    file_type, hdf_attr:
        forwarded to get_tomo_image
    align_coarse, align_method:
        forwarded to align_3D_tomo_image
        (align_method 1: old method, 2: 3D cross-correlation)

    Output:
    ----------------
    If X_eng == 0, saves "{file_path}/ali_{input file name}" with io.imsave.
    Otherwise saves "{file_path}/ali_recon_{scan_id}[_bin_{binning}].h5".
    '''
    fn = files_recon
    fn_short = os.path.basename(fn)
    print(f'aligning {fn_short} ...')
    img1, scan_id, X_eng = get_tomo_image(fn, file_type, hdf_attr)
    if circle_mask_ratio < 1:
        img1 = pyxas.circ_mask(img1, axis=0, ratio=circle_mask_ratio)
    bin_info = ''
    if binning > 1:
        img1 = pyxas.bin_image(img1, binning)
        # previously `bin_info == ...` (a discarded comparison), so binned
        # outputs never received the suffix used for the reference file
        bin_info = f'_bin_{binning}'
    img_ali = align_3D_tomo_image(img1, img_ref, circle_mask_ratio, align_method, align_coarse)
    if X_eng == 0:
        # no energy info: presumably the input was a tiff file, so save as tiff
        fn_save = f'{file_path}/ali_{fn_short}'
        print(f'saving aligned file: {os.path.basename(fn_save)}\n')
        io.imsave(fn_save, img_ali.astype(np.float32))
    else:
        fn_save = f'{file_path}/ali_recon_{scan_id}{bin_info}.h5'
        print(f'saving aligned file: {os.path.basename(fn_save)}\n')
        pyxas.save_hdf_file(fn_save, 'img', img_ali.astype(np.float32), 'scan_id', scan_id, 'X_eng', X_eng)
def rgb(img_r, img_g=None, img_b=None, norm_flag=1, filter_size=3, circle_mask_ratio=0.8):
    '''
    Compose an RGB image from up to three 2D channel images and display it.

    Inputs:
    -----------
    img_r: 2D array
        red channel (required)
    img_g, img_b: 2D array, optional
        green / blue channels. If None or empty, a zero channel is used.
    norm_flag: int/bool
        if true, each channel is divided by its maximum. All-zero channels
        are left as zeros.
    filter_size: int
        if >= 2, each channel is median filtered (scipy.signal.medfilt) with
        this kernel size. An even size is rounded up to the next odd value,
        because medfilt requires odd kernels.
    circle_mask_ratio: float
        ratio passed to pyxas.circ_mask (axis=2) on the composed image

    Output:
    ----------------
    img_rgb: float array of shape (rows, cols, 3). It is also shown in a new
    matplotlib figure.
    '''
    from scipy.signal import medfilt
    assert len(img_r.shape) == 2, '2D image only'
    s = img_r.shape
    if img_g is None or len(img_g) == 0:
        img_g = np.zeros(s)
    if img_b is None or len(img_b) == 0:
        img_b = np.zeros(s)
    channels = [img_r, img_g, img_b]
    if filter_size >= 2:
        kernel = int(filter_size)
        if kernel % 2 == 0:
            kernel += 1  # medfilt raises ValueError for even kernel sizes
        channels = [medfilt(ch, kernel) for ch in channels]
    img_rgb = np.zeros([s[0], s[1], 3])
    for i, ch in enumerate(channels):
        img_rgb[:, :, i] = ch
    img_rgb = pyxas.circ_mask(img_rgb, axis=2, ratio=circle_mask_ratio)
    if norm_flag:
        for i in range(3):
            t = np.max(img_rgb[:, :, i])
            if t != 0:  # skip all-zero channels: 0/0 would fill them with NaN
                img_rgb[:, :, i] /= t
    plt.figure()
    plt.imshow(img_rgb)
    return img_rgb
def move_3D_to_center(img, circle_mask_ratio=1):
    '''
    Shift a 3D volume so that its center of mass lands at shape / 2.

    Inputs:
    -----------
    img: 3D array
    circle_mask_ratio: float
        if < 1, the center of mass is computed on a circle-masked version
        (pyxas.circ_mask, axis=0, val=0). The unmasked `img` is still the
        array that gets shifted.

    Output:
    ----------------
    the shifted volume (pyxas.shift with order=0)
    '''
    from scipy.ndimage import center_of_mass
    weights = pyxas.circ_mask(img, axis=0, ratio=circle_mask_ratio, val=0) if circle_mask_ratio < 1 else img
    target = np.array(img.shape) / 2
    offset = target - np.array(center_of_mass(weights))  # target - current
    return pyxas.shift(img, list(offset), order=0)
def align_3D_coarse(img_ref, img1, circle_mask_ratio=1, method='other'):
    '''
    Coarsely align img1 to img_ref with a single 3D translation.

    Inputs:
    -----------
    img_ref, img1: 3D arrays
    circle_mask_ratio: float
        if < 1, a circular mask (axis=0, val=0) is applied before estimation
    method: str
        'center_mass': shift img1 so its center of mass matches img_ref's
        otherwise: combine the projection registrations along axes 0, 1 and 2
        (pyxas.align_3D_coarse_axes)

    Output:
    ----------------
    img_ali, shift_matrix
    img_ali is img1 shifted by shift_matrix (pyxas.shift, order=0).
    '''
    if method == 'center_mass':
        from scipy.ndimage import center_of_mass
        img0_crop = img_ref
        img1_crop = img1
        if circle_mask_ratio < 1:
            img0_crop = pyxas.circ_mask(img0_crop, axis=0, ratio=circle_mask_ratio, val=0)
            img1_crop = pyxas.circ_mask(img1_crop, axis=0, ratio=circle_mask_ratio, val=0)
        cm0 = np.array(center_of_mass(img0_crop))
        cm1 = np.array(center_of_mass(img1_crop))
        # Move img1's center of mass onto the reference's (target - current,
        # the same convention as move_3D_to_center). The old `cm1 - cm0`
        # pushed img1 further away.
        shift_matrix = cm0 - cm1
    else:
        shifts = [pyxas.align_3D_coarse_axes(img_ref, img1, circle_mask_ratio=circle_mask_ratio,
                                             axes=ax, shift_flag=0)
                  for ax in range(3)]
        # Each projection fills the two components orthogonal to its axis, so
        # every component is estimated twice; average the two estimates.
        shift_matrix = (shifts[0] + shifts[1] + shifts[2]) / 2.0
    print(f'shifts: {shift_matrix}')
    img_ali = pyxas.shift(img1, shift_matrix, order=0)
    return img_ali, shift_matrix
def align_3D_fine(img_ref, img1, circle_mask_ratio=1, sli_select=0, row_select=0, test_range=(-30, 30), sli_shift_guess=0, row_shift_guess=0, col_shift_guess=0, cen_mass_flag=0, ali_direction=(1, 1, 1)):
    '''
    Fine 3D alignment of img1 to img_ref.

    Steps:
      1. Take copies of both volumes and circle-mask them (axis=0) if
         circle_mask_ratio < 1. If any *_shift_guess is non-zero, first shift
         img1 by [sli_shift_guess, row_shift_guess, col_shift_guess] (order=0).
      2. If ali_direction[0] == 1, estimate the slice (height) shift by
         registering the [:, row_select] planes with pyxas.align_img, and
         apply it (order=1).
      3. Estimate the row/col shifts by registering slice `sli_select`.
         Components whose ali_direction entry is 0 are forced to 0.

    Inputs:
    -----------
    sli_select, row_select: int
        planes used for registration. 0 or out-of-range means the middle
        plane. They are overridden by center-of-mass positions if
        cen_mass_flag is set.
    test_range:
        unused. Kept for backward compatibility; it used to drive a
        correlation scan whose result was discarded.
    ali_direction: sequence of 3 ints
        shift [sli, row, col] only where the entry is 1

    Output:
    ----------------
    img_ali: the shifted (and masked, if circle_mask_ratio < 1) copy of img1
    shift_matrix: [sli_shft, rshft, cshft]. It does NOT include the
        *_shift_guess values.
    '''
    import time
    from scipy.ndimage import center_of_mass
    time_s = time.time()

    def _masked_copy(img):
        # copy first so the inputs are never modified
        img_tmp = img.copy()
        if circle_mask_ratio < 1:
            return pyxas.circ_mask(img_tmp, axis=0, ratio=circle_mask_ratio, val=0)
        return img_tmp

    img_ref_crop = _masked_copy(img_ref)
    img_raw_crop = _masked_copy(img1)
    if sli_shift_guess != 0 or row_shift_guess != 0 or col_shift_guess != 0:
        img_raw_crop = shift(img_raw_crop, [sli_shift_guess, row_shift_guess, col_shift_guess], order=0)

    if sli_select == 0 or sli_select >= img_ref_crop.shape[0]:
        sli_select = int(img_ref_crop.shape[0] / 2.0)
    if row_select == 0 or row_select >= img_ref_crop.shape[1]:
        row_select = int(img_ref_crop.shape[1] / 2.0)
    if cen_mass_flag:
        prj_ref = np.sum(img_ref_crop, axis=1)
        sli_select = int(center_of_mass(prj_ref)[0])
        prj_ref = np.sum(img_ref_crop, axis=0)
        row_select = int(center_of_mass(prj_ref)[0])
    print(f'aligning using sli = {sli_select}, row = {row_select}')

    # align height first (sli)
    sli_shft = 0  # stays 0 when height alignment is disabled (was a NameError)
    if ali_direction[0] == 1:
        print('aligning height ...')
        t1 = np.squeeze(img_ref_crop[:, row_select])
        t1 = t1 / np.mean(t1)
        t2 = np.squeeze(img_raw_crop[:, row_select])
        sli_shft, _ = pyxas.align_img(t1, t2, align_flag=0)
        img_raw_crop = shift(img_raw_crop, [sli_shft, 0, 0], order=1)

    # align row and col
    print('aligning row and col ...')
    t1 = img_ref_crop[sli_select]
    t1 = t1 / np.mean(t1)
    t2 = img_raw_crop[sli_select]
    rshft, cshft = pyxas.align_img(t1, t2, align_flag=0)
    if ali_direction[1] == 0:
        rshft = 0
    if ali_direction[2] == 0:
        cshft = 0
    img_ali = shift(img_raw_crop, [0, rshft, cshft], order=1)
    shift_matrix = [sli_shft, rshft, cshft]
    print(f'sli_shift: {sli_shft: 04.1f}, rshft: {rshft: 04.1f}, cshft: {cshft: 04.1f}')
    print(f'time elapsed: {time.time() - time_s:4.2f} sec')
    return img_ali, shift_matrix
def fit_xanes2D_align_tomo_proj(file_path='.', files_scan=None, binning=2, ref_index=-1, ref_rot_cen=-1, block_list=None, sli=None, ratio=0.8, file_prefix='fly', file_type='.h5'):
    '''
    Align the projections of each tomo scan to a reference scan and reconstruct.

    Inputs:
    -----------
    file_path: str
        directory containing all "fly_scans"
    files_scan: list of str
        scan files to process. If None/empty, they are collected from
        `file_path` with pyxas.retrieve_file_type(file_prefix, file_type).
    binning: int
        binning of the reconstruction
    ref_index: int
        index (into files_scan) of the reference scan. This scan should have
        good reconstruction quality and the full list of rotation angles.
        -1: use the last scan.
    ref_rot_cen: float
        rotation center of the reference scan (unbinned pixels).
        -1: estimate it with find_rot (cross-correlation of the projections
        at angle 0 and angle 180).
    block_list: list
        indexes of bad projections to exclude, e.g. list(np.arange(380, 550))
    sli: list
        forwarded to pyxas.align_proj_sub
    ratio: float (0 < ratio < 1)
        fraction of the projection image used for alignment, e.g. 0.6. Also
        used as the circular-mask ratio on the reconstruction.
    file_prefix: str
        prefix of the "fly_scan" files, e.g. 'fly'
    file_type: str
        e.g. '.h5'

    Output:
    ----------------
    None. Each aligned reconstruction is saved to
    "{file_path}/ali_recon/ali_recon_{scan_id}.h5".
    '''
    import time

    file_path = os.path.abspath(file_path)
    binning = int(binning)
    if files_scan is None or len(files_scan) == 0:
        files_scan = pyxas.retrieve_file_type(file_path, file_prefix=file_prefix, file_type=file_type)
    block_list = [] if block_list is None else block_list
    sli = [] if sli is None else sli
    num_files = len(files_scan)

    fn_ref = files_scan[ref_index]
    # only the reference angles are needed here; alignment re-reads fn_ref
    _, _, angle_ref = pyxas.retrieve_norm_tomo_image(fn_ref, index=ref_index, binning=binning)
    if ref_rot_cen == -1:
        rot_cen = find_rot(fn_ref) / binning
    else:
        rot_cen = ref_rot_cen / binning

    new_dir = f'{file_path}/ali_recon'
    os.makedirs(new_dir, exist_ok=True)

    for i, fn in enumerate(files_scan):
        with h5py.File(fn, 'r') as f1:
            scan_id = np.array(f1['scan_id'])
            angle1 = np.array(f1['angle'])
        theta1 = angle1 / 180 * np.pi
        t_start = time.time()
        img1_ali, eng1, rshft, cshft = pyxas.align_proj_sub(fn_ref, angle_ref, fn, angle1, binning,
                                                            ratio=ratio, sli=sli, ali_method='stackreg')
        print(f'#{i}/{num_files}, time elapsed: {time.time()-t_start}\n')
        img1_ali = pyxas.norm_txm(img1_ali)
        rec = pyxas.recon_sub(img1_ali, theta1, rot_cen, block_list)
        rec = pyxas.circ_mask(rec, axis=0, ratio=ratio, val=0)
        print('saving files...\n')
        fn_save = f'{new_dir}/ali_recon_{scan_id}.h5'
        with h5py.File(fn_save, 'w') as hf:
            hf.create_dataset('img', data=rec.astype(np.float32))
            hf.create_dataset('scan_id', data=scan_id)
            hf.create_dataset('XEng', data=eng1)
            hf.create_dataset('angle', data=angle1)
        print(f'{fn_save} saved\n')