def get_transformation(Tstack, reference):
    """Rigid-body register a time stack stored as [x, y, t].

    Parameters
    ----------
    Tstack : uint16 3D array [x, y, t]
        Stack to register; frames lie along the last axis.
    reference : str
        Reference passed to ``StackReg.register_stack``: 'previous' registers
        each frame to the previous one, 'first' registers every frame to the
        first one.

    Returns
    -------
    outStack : uint16 3D array [x, y, t]
        Registered stack, clipped to the uint16 range before casting.
    trans_matrix : ndarray
        Transformation matrices returned by ``register_stack``.
    """
    # pystackreg expects frames on axis 0.
    stack_txy = np.transpose(Tstack, (2, 0, 1))
    sr = StackReg(StackReg.RIGID_BODY)
    trans_matrix = sr.register_stack(stack_txy, reference=reference)
    registered = sr.transform_stack(stack_txy)
    # Interpolation can produce values slightly below 0 or above the input
    # maximum; casting those straight to uint16 wraps around, so clip first.
    limits = np.iinfo(np.uint16)
    outStack = np.clip(registered, limits.min, limits.max).astype(np.uint16)
    outStack = np.transpose(outStack, (1, 2, 0))
    return outStack, trans_matrix
def sorting(top_path):
    """Affine-register every ``Linear_Amp_FFT2X`` stack found below *top_path*.

    For each matching directory (searched recursively), the stack returned by
    ``im_open`` is registered to its first frame with a 5-frame moving average.
    Two float32 TIFFs are written to ``folder.parents[2] / 'Blink'``:
    ``full_realstack.tif`` (registered stack) and ``AVG_full_realstack.tif``
    (its mean along axis 0). Folders for which ``im_open`` raises
    ``ValueError`` are skipped.
    """
    sr = StackReg(StackReg.AFFINE)
    for folder in top_path.glob('**/Linear_Amp_FFT2X'):
        # Explicit check: an `assert` here would be stripped under `python -O`.
        if not folder.is_dir():
            continue
        print(folder)
        # e.g. ../loc1/<timestamp>/oct/Linear_Amp_FFT2X -> ../loc1/Blink;
        # reuse the folder if it already exists.
        stack_pth = folder.parents[2] / 'Blink'
        stack_pth.mkdir(exist_ok=True)
        try:
            stack = im_open(folder)
        except ValueError:
            # Per the original note: dispersion compensation creates all
            # folders, and the linear one is usually empty.
            continue
        aff = sr.register_transform_stack(stack, reference='first', n_frames=1,
                                          moving_average=5)
        avg_aff = np.mean(aff, axis=0)
        io.imsave(str(stack_pth / 'full_realstack.tif'), aff.astype('float32'))
        io.imsave(str(stack_pth / 'AVG_full_realstack.tif'), avg_aff.astype('float32'))
def reg(fix, mov):
    """Translation-register volume *mov* onto *fix* in two passes.

    Pass 1 rotates both volumes with ``np.rot90(axes=(0, 2))`` and registers
    them frame by frame via ``reg_frames``. The translation component in
    matrix row 1 is zeroed, so only the other shift is applied. Pass 2 rotates
    the result back and registers it frame by frame against *fix*.
    Per the author, volumes are assumed to be (x, z, y) with equal x and y sizes.
    """
    registrar = StackReg(StackReg.TRANSLATION)

    # Pass 1: work in the rotated orientation.
    fix_rotated = np.rot90(fix, axes=(0, 2))
    mov_rotated = np.rot90(mov, axes=(0, 2))
    first_pass_mats = reg_frames(fix_rotated, mov_rotated)
    # Per the author: in this orientation row 1's shift is the z shift; drop it.
    first_pass_mats[:, 1, -1] = 0
    shifted = registrar.transform_stack(mov_rotated, tmats=first_pass_mats)

    # Pass 2: rotate back and register frame by frame.
    restored = np.rot90(shifted, axes=(2, 0))
    second_pass_mats = reg_frames(fix, restored)
    return registrar.transform_stack(restored, tmats=second_pass_mats)
def run(self, ips, imgs, para=None):
    """Register an image stack with pystackreg, optionally on downsampled copies.

    ``para`` keys read here: 'diag' (target diagonal length for downsampling;
    0 disables it), 'sigma' (Gaussian blur, 0 disables it), 'trans' (name of a
    ``StackReg`` transformation constant), 'ref' (registration reference),
    'tab' (show the matrix table), 'new' ('None', 'Inplace' or 'New').
    """
    shape = np.array(ips.img.shape)
    k = para['diag'] / np.sqrt((shape ** 2).sum())
    size = tuple((shape * k).astype(np.int16))
    IPy.info('down sample...')
    news = []
    for img in imgs:
        if k != 0:
            img = tf.resize(img, size)
        if para['sigma'] != 0:
            img = ndimg.gaussian_filter(img, para['sigma'])
        news.append(img)

    IPy.info('register...')
    # getattr instead of eval: para['trans'] comes from user input.
    sr = StackReg(getattr(StackReg, para['trans']))
    sr.register_stack(np.array(news), reference=para['ref'])
    mats = sr._tmats.reshape((sr._tmats.shape[0], -1)).copy()
    if k != 0:
        # Map matrices found on the k-scaled images back to full resolution:
        # T_full = S^-1 @ T_small @ S with S = diag(k, k, 1).
        mats[:, [0, 1, 3, 4, 6, 7]] *= k
        mats[:, [0, 1, 2, 3, 4, 5]] /= k
    full_res_tmats = mats.reshape((-1, 3, 3))
    if para['tab']:
        IPy.show_table(pd.DataFrame(
            mats, columns=['A%d' % (i + 1) for i in range(mats.shape[1])]),
            title='%s-Tmats' % ips.title)
    if para['new'] == 'None':
        return

    IPy.info('transform...')
    for i, tmat in enumerate(full_res_tmats):
        tform = tf.ProjectiveTransform(matrix=tmat)
        # preserve_range keeps the original intensity units; the old
        # "-= min; *= (max - min)" did not invert warp's float conversion.
        img = tf.warp(imgs[i], tform, preserve_range=True)
        if para['new'] == 'Inplace':
            imgs[i][:] = img
        if para['new'] == 'New':
            news[i] = img.astype(ips.img.dtype)
        self.progress(i, len(imgs))
    if para['new'] == 'New':
        IPy.show_img(news, '%s-reg' % ips.title)
def align_img_stackreg(img_ref, img, align_flag=1, method='translation'):
    '''
    Align ``img`` to ``img_ref`` with pystackreg.

    :param img_ref: reference image
    :param img: image to align
    :param align_flag: 1: perform the alignment; 0: return the shifts only
    :param method: 'translation': x, y shift
                   'rigid': translation + rotation
                   'scaled rotation': translation + rotation + scaling
                   'affine': translation + rotation + scaling + shearing
                   any other value: nothing is aligned; a copy of ``img``,
                   zero shifts and ``sr=None`` are returned
    :return: align_flag == 1: img_ali, row_shift, col_shift, sr
             align_flag == 0: row_shift, col_shift, sr
             (row_shift and col_shift are only meaningful for translation)
    '''
    if method == 'translation':
        sr = StackReg(StackReg.TRANSLATION)
    elif method == 'rigid':
        sr = StackReg(StackReg.RIGID_BODY)
    elif method == 'scaled rotation':
        sr = StackReg(StackReg.SCALED_ROTATION)
    elif method == 'affine':
        sr = StackReg(StackReg.AFFINE)
    else:
        # Previously `sr` became a plain list here and `sr.register` crashed.
        print('unrecognized align method, no aligning performed')
        if align_flag:
            return np.copy(img), 0.0, 0.0, None
        return 0.0, 0.0, None

    tmat = sr.register(img_ref, img)
    # tmat[0, 2] is the x (column) and tmat[1, 2] the y (row) translation.
    row_shift = -tmat[1, 2]
    col_shift = -tmat[0, 2]
    if align_flag:
        img_ali = sr.transform(img)
        return img_ali, row_shift, col_shift, sr
    return row_shift, col_shift, sr
def correct_drift(reference, move):
    """Rigid-body register *move* onto *reference*.

    Returns ``(corrected_image, transformation_matrix)``, where the matrix is
    the one found by ``StackReg.register`` and then applied to *move*.
    """
    registrar = StackReg(StackReg.RIGID_BODY)
    tmat = registrar.register(reference, move)
    corrected = registrar.transform(move, tmat)
    return corrected, tmat
def reg(stack):
    """Remove vertical frame-to-frame jitter from *stack*.

    Registration runs on a ``gauss``-filtered copy, with every frame
    referenced to the first. The x-translation component (matrix [0, 2]) is
    zeroed, so only the y shift is applied to the unfiltered stack.
    """
    registrar = StackReg(StackReg.TRANSLATION)
    smoothed = gauss(stack)
    mats = registrar.register_stack(smoothed, reference='first', n_frames=1)
    mats[:, 0, -1] = 0  # discard x translation
    return registrar.transform_stack(stack, tmats=mats)
def align_simple(stack_img, transformation=StackReg.TRANSLATION, reference="previous"):
    """Register and transform *stack_img* in one pass.

    Parameters
    ----------
    stack_img : array
        Image stack with frames along axis 0 (pystackreg's default).
    transformation : int
        A ``StackReg`` transformation constant.
    reference : str
        'previous' or 'first'; forwarded to ``register_stack``. Previously
        this argument was ignored and 'previous' was always used.

    Returns
    -------
    float32 array of the registered stack.
    """
    sr = StackReg(transformation)
    tmats_ = sr.register_stack(stack_img, reference=reference)
    # One transform is enough; the old 10x loop recomputed the same result.
    out_stk = sr.transform_stack(stack_img, tmats=tmats_)
    return np.float32(out_stk)
def get_transformation(Tstack, reference='first'):
    """Rigid-body register a stack whose frames lie along axis 0.

    Parameters
    ----------
    Tstack : array
        Stack to register (frames on axis 0, pystackreg's default).
    reference : str, optional
        'first' (default, the previous hard-coded behavior) or 'previous';
        forwarded to ``StackReg.register_stack``.

    Returns
    -------
    outStack : uint16 array
        Registered stack, clipped to the uint16 range before casting.
    trans_matrix : ndarray
        Transformation matrices returned by ``register_stack``.
    """
    sr = StackReg(StackReg.RIGID_BODY)
    trans_matrix = sr.register_stack(Tstack, reference=reference)
    registered = sr.transform_stack(Tstack)
    # Interpolation overshoot (negative values / above max) would wrap on a
    # direct uint16 cast; clip first.
    limits = np.iinfo(np.uint16)
    outStack = np.clip(registered, limits.min, limits.max).astype(np.uint16)
    return outStack, trans_matrix
def process(pth):
    """Intravolume-register every ``*.tif`` file directly inside *pth*.

    Results go to ``pth / 'reg_stacks'``; if that folder already exists,
    ``ex`` is called with an error message. For each file, frames are
    translation-registered to the first frame on a ``gauss``-filtered copy.
    The x translation is zeroed (per the author, horizontal motion is
    irrelevant at this acquisition speed). The result is saved as float32
    ``reg_<name>``.
    """
    print(f'\nProcessing {pth}')
    save_pth = pth / 'reg_stacks'
    try:
        save_pth.mkdir()
    except FileExistsError:
        ex('Save File for reg stacks or tmats already exists. Delete and re-run.')

    # Translation only: per the author, no intra-volume rotation/shear expected.
    registrar = StackReg(StackReg.TRANSLATION)

    tif_files = pth.glob('*.tif')
    loop_starts = time.time()
    for tif in tif_files:
        print(f'Processing: {tif}')
        # str() needed for imread, otherwise only one frame is loaded.
        volume = io.imread(str(tif))
        mats = registrar.register_stack(gauss(volume), reference='first', n_frames=1)
        mats[:, 0, -1] = 0  # drop horizontal (x) translation
        volume = registrar.transform_stack(volume, tmats=mats)

        save_name = save_pth / f'reg_{tif.name}'
        io.imsave(arr=img_as_float32(volume), fname=str(save_name))
        print(f'Intravolume registration complete. File saved at {save_name}')

    end_time = time.time()
    print(f'Run took {(end_time-loop_starts)/60:.2f} minutes')
def test_registration_transformation(stack, stack_unregistered):
    """Registering the unregistered stack must reproduce the stored result.

    BILINEAR is registered against the first frame; all other transformations
    against the previous frame. The comparison is on uint16-converted images
    with an absolute tolerance of 1.
    """
    transformation = stack["transformation"]
    sr = StackReg(transformation)
    reference = "first" if transformation == StackReg.BILINEAR else "previous"
    out = sr.register_transform_stack(stack_unregistered, reference=reference)
    expected = stack["registered"]
    np.testing.assert_allclose(to_uint16(out), to_uint16(expected), rtol=1e-7, atol=1)
def pystackreg_work(self, uncorrected, option, worker_progress_signal, worker_max_progress_signal):
    """Register *uncorrected* to its first frame, reporting progress via signals.

    ``worker_max_progress_signal`` is emitted once, on the first callback, with
    the total iteration count. ``worker_progress_signal`` is emitted on every
    callback. Returns ``{'corrected': <registered stack>}``.
    """
    self.end_set = False

    def report_progress(current_iteration, end_iteration):
        # Only announce the maximum the first time through.
        if not self.end_set:
            worker_max_progress_signal.emit(end_iteration)
            self.end_set = True
        worker_progress_signal.emit(current_iteration)

    registrar = StackReg(option)
    corrected = registrar.register_transform_stack(
        uncorrected, reference='first', progress_callback=report_progress)
    return {'corrected': corrected}
def align_3D_coarse_axes(img_ref, img1, circle_mask_ratio=0.6, axes=0, shift_flag=1):
    '''
    Coarsely align a 3D volume to a reference volume via 2D projections.

    Both volumes are circle-masked with ``pyxas.circ_mask`` (when
    ``circle_mask_ratio < 1``). They are cropped to the central
    ``circle_mask_ratio`` fraction of slices along axis 0 and summed along
    ``axes``. The two projections are then registered with a pystackreg
    translation.

    Inputs:
    -----------
    img_ref: 3D array
        reference volume
    img1: 3D array
        volume to align to ``img_ref``
    circle_mask_ratio: float
        mask ratio for ``pyxas.circ_mask`` (applied only when < 1); also the
        fraction of central slices along axis 0 used for the projection
        (clamped to the whole volume when > 1)
    axes: int
        axis (0, 1 or 2) along which the volumes are projected
    shift_flag: int
        if truthy, also shift ``img1`` and return it

    Output:
    ----------------
    shift_flag truthy: (aligned volume, shift_matrix)
    shift_flag falsy: shift_matrix
    '''
    def _prepare(vol):
        # Work on a copy so the caller's array is never touched by the mask.
        if circle_mask_ratio < 1:
            return pyxas.circ_mask(vol.copy(), axis=0, ratio=circle_mask_ratio, val=0)
        return vol.copy()

    img_ref_crop = _prepare(img_ref)
    n_slices = img_ref_crop.shape[0]
    # Clamp to [0, n_slices]: for circle_mask_ratio > 1 the raw start index is
    # negative and would wrap around, selecting only the last few slices.
    start = max(0, int(n_slices * (0.5 - circle_mask_ratio / 2)))
    stop = min(n_slices, int(n_slices * (0.5 + circle_mask_ratio / 2)))
    prj0 = np.sum(img_ref_crop[start:stop], axis=axes)

    img_raw_crop = _prepare(img1)
    prj1 = np.sum(img_raw_crop[start:stop], axis=axes)

    sr = StackReg(StackReg.TRANSLATION)
    tmat = sr.register(prj0, prj1)
    r = -tmat[1, 2]
    c = -tmat[0, 2]

    # Place the 2D (row, col) shift on the two axes that survive the projection.
    if axes == 0:
        shift_matrix = np.array([0, r, c])
    elif axes == 1:
        shift_matrix = np.array([r, 0, c])
    elif axes == 2:
        shift_matrix = np.array([r, c, 0])
    else:
        shift_matrix = np.array([0, 0, 0])

    if shift_flag:
        img_ali = pyxas.shift(img1, shift_matrix, order=0)
        return img_ali, shift_matrix
    return shift_matrix
def alignChannels(self):
    """Translation-align each colour channel of ``self.stackedImage`` in place.

    Each channel is registered against the grayscale version of the image.
    The resulting matrix is then applied to that channel with
    ``cv2.warpPerspective`` (edges replicated). A progress message is sent per
    channel.
    NOTE(review): ``sr.register(channel, gray)`` uses the channel as reference;
    this pairs with warpPerspective's forward-matrix convention — confirm.
    """
    height, width = self.stackedImage.shape[:2]
    gray = cv2.cvtColor(self.stackedImage, cv2.COLOR_BGR2GRAY)
    registrar = StackReg(StackReg.TRANSLATION)
    channels = cv2.split(self.stackedImage)
    for idx, channel in enumerate(channels):
        M = registrar.register(channel, gray)
        self.stackedImage[:, :, idx] = cv2.warpPerspective(
            self.stackedImage[:, :, idx], M, (width, height),
            borderMode=cv2.BORDER_REPLICATE)
        g.ui.childConn.send("Aligning RGB")
def rough_coregister(ref_volume, realign_volume, affine_transform, axis=2):
    """Rapid coregistration method using the PyStackReg package.

    INPUTS:
        ref_volume (np.array): Reference volume to realign to (3D)
        realign_volume (np.array): Volume to realign onto ref_volume (3D)
        affine_transform (np.array): Affine transform carrying translation and
            scale information (rotations are not supported)
        axis (int): slice axis to realign on; the length of this axis changes
            in the output
    OUTPUTS:
        ref_volume, rescaled_volume: the overlapping parts of both volumes,
            with realign_volume transformed slice by slice
    """
    # Vertical shift in number of voxels (slice_delta = mm_delta / slice_height[mm]).
    shift = -int(affine_transform[2, -1] / affine_transform[2, 2])

    # Keep only the shared slices.
    if shift >= 0:
        realign_volume = realign_volume[..., shift:]
        ref_volume = ref_volume[..., :realign_volume.shape[-1]]
    else:
        # NOTE(review): the end index realign_volume.shape[-1] (not
        # -shift + realign_volume.shape[-1]) drops part of the valid overlap;
        # kept as-is, confirm intent.
        ref_volume = ref_volume[..., -shift:realign_volume.shape[-1]]
        realign_volume = realign_volume[..., :ref_volume.shape[-1]]

    # In-plane affine (scale + translation only) shared by every slice.
    sr = StackReg(StackReg.AFFINE)
    aff_transform = np.array(
        [[affine_transform[0, 0], 0,
          affine_transform[0, 0] * affine_transform[0, -1]],
         [0, affine_transform[1, 1],
          affine_transform[1, 1] * affine_transform[1, -1]],
         [0, 0, 1]])
    tmats = aff_transform[None, :]

    # One slice() per dimension (was len(), i.e. the size of axis 0).
    slc = [slice(None)] * realign_volume.ndim
    rescaled_volume = np.zeros_like(realign_volume)
    n_slices = realign_volume.shape[axis]
    for i in range(n_slices):
        slc[axis] = slice(i, i + 1)
        index = tuple(slc)
        print("\rAffine transformation on slice {}/{}".format(i + 1, n_slices),
              end=' ' * 5)
        # The single selected slice is the only frame, along `axis`.
        rescaled_volume[index] = sr.transform_stack(
            realign_volume[index], axis=axis, tmats=tmats)
    print('')
    return ref_volume, rescaled_volume
def register_transform_save_FOV(directory, FOV, phase_channel, fluor_channels):
    """Drift-correct one field of view and write all channels back to disk.

    The phase-contrast images are drift-corrected with
    ``drift_correct_images``. The resulting per-image matrices are then
    applied (rigid body) to every fluorescence channel. Each image is saved,
    via ``save_image``, to the path it was read from; fluorescence images are
    saved as uint16.
    """
    phase_paths = get_image_list(directory, FOV, phase_channel)
    images, trans_matrices = drift_correct_images(phase_paths)
    Parallel(n_jobs=-1)(
        delayed(save_image)(img_path, image)
        for img_path, image in tqdm(
            zip(phase_paths, images),
            total=len(phase_paths),
            desc="Writing phase contrast FOV {} to disk".format(FOV),
            leave=False))

    for fluor_channel in fluor_channels:
        fl_paths = get_image_list(directory, FOV, fluor_channel)
        raw_images = [io.imread(img_path) for img_path in fl_paths]
        registrar = StackReg(StackReg.RIGID_BODY)
        warped = Parallel(n_jobs=-1)(
            delayed(registrar.transform)(image, tmat)
            for image, tmat in tqdm(
                zip(raw_images, trans_matrices),
                total=len(raw_images),
                desc="Transforming FL channel {} in FOV {}".format(fluor_channel, FOV),
                leave=False))
        warped = [image.astype(np.uint16) for image in warped]
        Parallel(n_jobs=-1)(
            delayed(save_image)(img_path, image)
            for img_path, image in tqdm(
                zip(fl_paths, warped),
                total=len(warped),
                desc="Writing FL channel {} images in FOV {} to disk".format(fluor_channel, FOV),
                leave=False))
def registration(hstack, bfchannel, nbchannels, file, output_dir):
    """Rigid-body register a multichannel time stack using one reference channel.

    ``hstack`` is indexed as ``[t, channel, y, x]``. Matrices are computed on
    channel ``bfchannel`` (each frame against the previous registered one, as
    in the ImageJ StackReg plugin) and applied to channels
    ``0..nbchannels-1``. The matrices are saved to
    ``output_dir/file/tmats.npy`` so they can be reused later.

    Returns a float array of shape ``(t, nbchannels, y, x)``.
    """
    # bfchannel is an integer index, so convert before concatenating.
    print("This is the bfchannel: " + str(bfchannel))
    sr = StackReg(StackReg.RIGID_BODY)
    tmats = sr.register_stack(hstack[:, bfchannel, :, :], reference='previous')

    # Was a plain list, which does not support 4D slice assignment.
    n_frames, _, height, width = hstack.shape
    hstackreg = np.zeros((n_frames, nbchannels, height, width))
    for channel in range(nbchannels):
        hstackreg[:, channel, :, :] = sr.transform_stack(hstack[:, channel, :, :])

    np.save(os.path.join(output_dir, file, 'tmats.npy'), tmats)
    return hstackreg
def register(registered_filename, registered_filepath, raw_filename, raw_filepath):
    """Rigid-body register a volume file to its first frame and save it as uint8.

    Reads ``raw_filepath + raw_filename`` with ``imageio.volread``. Creates
    ``registered_filepath`` if needed and writes
    ``registered_filepath/registered_filename`` as a BigTIFF.
    """
    image_stack = imageio.volread(raw_filepath + raw_filename)
    sr = StackReg(StackReg.RIGID_BODY)
    registered_data = sr.register_transform_stack(image_stack, reference='first', verbose=True)
    # Clip before casting: interpolation can leave values outside [0, 255],
    # which a direct uint8 cast would wrap around.
    finished_file = np.uint8(np.clip(registered_data, 0, 255))
    try:
        os.mkdir(registered_filepath)
        print('directory made, saving registered image:' + registered_filename)
    except FileExistsError:
        # Only "already exists" is expected; other OS errors now propagate
        # instead of being reported as an existing directory.
        print('directory exists, saving registered image:' + registered_filename)
    imageio.volwrite(registered_filepath + '/' + registered_filename, finished_file,
                     bigtiff=True)
def dewiggle(stk):
    """Remove high-frequency x jitter while keeping slow x drift.

    Frames are translation-registered to the previous frame. A cubic
    polynomial (degree 3, not a straight line) is fitted to the x shifts.
    Only the residual around that fit is applied back to the stack.
    """
    print('Wiggles')
    registrar = StackReg(StackReg.TRANSLATION)
    mats = registrar.register_stack(stk, reference='previous')
    x_shifts = mats[:, 0, -1]
    frames = np.arange(len(x_shifts))
    trend = np.polyval(np.polyfit(frames, x_shifts, 3), frames)
    mats[:, 0, -1] = x_shifts - trend
    return registrar.transform_stack(stk, tmats=mats)
def reg_frames(fix, mov, sr=None):
    """Register each frame of *mov* to the matching frame of *fix*.

    Parameters
    ----------
    fix, mov : sequences of 2D frames (frames along axis 0)
    sr : object with a ``register(ref, mov)`` method, optional
        Defaults to a fresh ``StackReg(StackReg.TRANSLATION)`` per call.
        Previously a single instance was created at definition time and shared
        by all calls.

    Returns
    -------
    ndarray of shape (mov.shape[0], 3, 3). Frames of *mov* with no counterpart
    in *fix* get an identity matrix, not an all-zero matrix that would
    collapse the image when applied.
    """
    if sr is None:
        sr = StackReg(StackReg.TRANSLATION)
    tmats = np.tile(np.eye(3), (mov.shape[0], 1, 1))
    for i, (fix_frame, mov_frame) in enumerate(zip(fix, mov)):
        tmats[i] = sr.register(fix_frame, mov_frame)
    return tmats
def test_different_axis(stack, stack_unregistered, frame_axis):
    """Registration along a non-default frame axis must match the stored result.

    Both the expected and the unregistered stacks get their frame axis moved
    to ``frame_axis``. The expected stack is updated in place inside ``stack``.
    """
    stack["registered"] = np.moveaxis(stack["registered"], 0, frame_axis)
    moved_unregistered = np.moveaxis(stack_unregistered, 0, frame_axis)
    transformation = stack["transformation"]
    sr = StackReg(transformation)
    reference = "first" if transformation == StackReg.BILINEAR else "previous"
    out = sr.register_transform_stack(moved_unregistered, reference=reference,
                                      axis=frame_axis)
    expected = stack["registered"]
    assert out.shape == expected.shape
    np.testing.assert_allclose(to_uint16(out), to_uint16(expected), rtol=1e-7, atol=1)
def find_rot(fn, thresh=0.05, method=1):
    """Estimate the tomography rotation center from the 0 and ~180 degree projections.

    Reads ``angle``, ``img_bkg_avg`` and ``img_tomo`` from HDF5 file *fn*.
    Background-normalises and -log transforms both projections, mirrors the
    180 degree one, and estimates the column shift two ways:
    ``dftregistration`` (``method`` truthy) or pystackreg translation
    (``method`` falsy). Both estimates are printed.
    """
    from pystackreg import StackReg

    sr = StackReg(StackReg.TRANSLATION)
    # Context manager so the file is closed even if a read fails.
    with h5py.File(fn, "r") as f:
        ang = np.array(list(f["angle"]))
        img_bkg = np.squeeze(np.array(f["img_bkg_avg"]))
        if np.abs(ang[0]) < np.abs(ang[0] - 90):  # e.g. rotate from 0 - 180 deg
            idx180 = np.abs(ang - ang[0] - 180).argmin()
        else:  # e.g. rotate from -90 - 90 deg
            idx180 = np.abs(ang - np.abs(ang[0])).argmin()
        img0 = np.array(list(f["img_tomo"][0]))
        img180_raw = np.array(list(f["img_tomo"][idx180]))

    img0 = img0 / img_bkg
    img180 = (img180_raw / img_bkg)[:, ::-1]
    s = np.squeeze(img0.shape)
    im1 = -np.log(img0)
    im2 = -np.log(img180)
    # -log(0) is inf and -log(<0) is NaN; zero both (previously only NaN was).
    im1[~np.isfinite(im1)] = 0
    im2[~np.isfinite(im2)] = 0
    im1[im1 < thresh] = 0
    im2[im2 < thresh] = 0
    im1 = medfilt2d(im1, 5)
    im2 = medfilt2d(im2, 5)

    results = dftregistration(np.fft.fft2(im1), np.fft.fft2(im2))
    col_shift = results[3]
    rot_cen = s[1] / 2 + col_shift / 2 - 1

    tmat = sr.register(im1, im2)
    cshft = -tmat[0, 2]
    rot_cen0 = s[1] / 2 + cshft / 2 - 1
    print(f"rot_cen = {rot_cen} or {rot_cen0}")
    if method:
        return rot_cen
    return rot_cen0
def calculate_shifts_stackreg(stack):
    """
    Calculate shifts using PyStackReg.

    Args
    ----------
    stack : TomoStack object
        The image series to be aligned (``stack.data``), each frame
        registered to the previous one.

    Returns
    ----------
    shifts : NumPy array
        The negated X- and Y-translations (matrix entries [0, 2] and [1, 2])
        for each image.
    """
    registrar = StackReg(StackReg.TRANSLATION)
    tmats = registrar.register_stack(stack.data, reference='previous')
    translations = [mat[0:2, 2] for mat in tmats]
    return -np.array(translations)
class ImageTransformOpticalFlow():
    """
    Register a stack of measured images to predicted images for AET.

    Each measured image is rigid-body registered (pystackreg) to its predicted
    counterpart and warped accordingly. Images are indexed along the last
    axis.

    Input parameters:
        - shape: shape of a single image
        - method: accepted but currently unused
    """

    def __init__(self, shape, method="turboreg"):
        self.shape = shape
        self.x_lin, self.y_lin = np.meshgrid(np.arange(self.shape[1]),
                                             np.arange(self.shape[0]))
        self.xy_lin = np.stack((self.x_lin, self.y_lin)).astype('float32')
        self.sr = StackReg(StackReg.RIGID_BODY)

    def _estimate_single(self, predicted, measured):
        """Return (warped measured image, first six entries of the 3x3 matrix)."""
        assert predicted.shape == self.shape
        assert measured.shape == self.shape
        aff_mat = self.sr.register(measured, predicted)
        tform = transform.AffineTransform(matrix=aff_mat)
        measured_warp = transform.warp(measured, tform.inverse, cval=1.0, order=5)
        return measured_warp, aff_mat[:2, :].flatten()

    def estimate(self, predicted_stack, measured_stack):
        """Register every image; return (warped measured tensor, 6 x N transform tensor).

        The warped stack goes back to the GPU if the measured input was on it;
        the transform tensor is always on the CPU.
        """
        assert predicted_stack.shape == measured_stack.shape
        n_images = measured_stack.shape[2]
        transform_vec_list = np.zeros((6, n_images), dtype="float32")

        measured_on_gpu = measured_stack.is_cuda
        # torch -> numpy (copies live on the host)
        predicted_np = np.array(predicted_stack.cpu().detach())
        measured_np = np.array(measured_stack.cpu().detach())

        for idx in range(n_images):
            warped, vec = self._estimate_single(predicted_np[..., idx],
                                                measured_np[..., idx])
            measured_np[..., idx] = warped
            transform_vec_list[..., idx] = vec

        measured_out = torch.tensor(measured_np)
        if measured_on_gpu:
            measured_out = measured_out.cuda()
        return measured_out, torch.tensor(transform_vec_list)
def align_stack_iter(
    stack,
    ref_stack_void=True,
    ref_stack=None,
    transformation=StackReg.TRANSLATION,
    method=("previous", "first"),
    max_iter=2,
):
    """Iteratively register *stack* using successive reference modes.

    In each of ``max_iter`` rounds, and for each mode in ``method``, the
    matrices are computed on ``ref_stack`` (or on *stack* itself when
    ``ref_stack_void`` is true). They are applied to both ``ref_stack`` and
    *stack*. Returns the final stack as float32.
    """
    if ref_stack_void:
        ref_stack = stack
    for _ in range(max_iter):
        registrar = StackReg(transformation)
        for idx, mode in enumerate(method):
            print(idx, mode)
            tmats = registrar.register_stack(ref_stack, reference=mode)
            ref_stack = registrar.transform_stack(ref_stack)
            stack = registrar.transform_stack(stack, tmats=tmats)
    return np.float32(stack)
def get_stackreg_shifts(stack):
    """
    Calculate alignment shifts for an image stack using PyStackReg.

    Args
    ----------
    stack : Hyperspy Signal2D
        Image stack to align (``stack.data``), registered with pystackreg's
        default reference.

    Returns
    ----------
    sr_shifts : NumPy array
        Per frame, the translation entries ([1, 2], [0, 2]), i.e. the
        (x, y) translation reversed to (y, x).
    """
    registrar = StackReg(StackReg.TRANSLATION)
    tmats = registrar.register_stack(stack.data)
    return np.array([mat[:-1, 2][::-1] for mat in tmats])
def stack_reg_consecutive_frames(z_reg, top_Z):
    """Rigid-body XY-register every time point to time point 0, in place.

    ``z_reg`` is a list of 4D arrays indexed ``[z, y, x, channel]``. For each
    time point t >= 1, plane ``top_Z`` of channel 1 is registered to the same
    plane/channel of ``z_reg[0]``. The resulting matrix is applied to every z
    plane and every channel of that time point. Returns ``z_reg``.
    """
    T = len(z_reg)
    Z, _, _, C = z_reg[0].shape
    # z_reg[0] is never modified, so the reference can be taken once.
    ref = z_reg[0][top_Z, :, :, 1]
    for t in range(T - 1):
        print(f'XY registering t = {t}')
        target_img = z_reg[t + 1][top_Z, :, :, 1]
        sr = StackReg(StackReg.RIGID_BODY)
        reg_matrix = sr.register(ref, target_img)
        # Apply to the whole stack; iterate over the actual channel count
        # (was hard-coded to 3).
        for z in range(Z):
            for c in range(C):
                z_reg[t + 1][z, :, :, c] = sr.transform(z_reg[t + 1][z, :, :, c],
                                                        tmat=reg_matrix)
    return z_reg
def align(frames, reference, transformation, normalize, pa1, pa2, conn):
    """Compute one registration matrix per frame against a cached reference image.

    The reference is loaded from ``g.tmp + "cache/" + reference + ".png"``.
    Each frame is loaded from its path with "frames" replaced by "cache".
    When both ``pa1`` and ``pa2`` differ from (0, 0), images are cropped to
    that processing area. Images are downscaled so the reference height is at
    most 100 px, and optionally min-max normalised. The translation entries
    are scaled back to full resolution. A progress message is sent per frame.
    """
    use_area = pa1 != (0, 0) and pa2 != (0, 0)

    ref = cv2.imread(g.tmp + "cache/" + reference + ".png", cv2.IMREAD_GRAYSCALE)
    if use_area:
        ref = ref[pa1[1]:pa2[1], pa1[0]:pa2[0]]
    registrar = StackReg(transformation)
    h, w = ref.shape[:2]
    scale = min(1.0, (100 / h))
    target_size = (int(w * scale), int(h * scale))
    ref = cv2.resize(ref, target_size)
    if normalize:
        ref = cv2.normalize(ref, ref, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)

    tmats = []
    for frame in frames:
        mov = cv2.imread(frame.replace("frames", "cache"), cv2.IMREAD_GRAYSCALE)
        if use_area:
            mov = mov[pa1[1]:pa2[1], pa1[0]:pa2[0]]
        mov = cv2.resize(mov, target_size)
        if normalize:
            mov = cv2.normalize(mov, mov, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
        M = registrar.register(mov, ref)
        # Undo the downscale on the translation components.
        M[0][2] /= scale  # X
        M[1][2] /= scale  # Y
        tmats.append(M)
        conn.send("Aligning Frames")
    return tmats
def align_stack(stack_img, ref_image_void=True, ref_stack=None, transformation=StackReg.TRANSLATION,
                reference="previous"):
    """Image registration flow using pystackreg.

    Matrices are computed on *stack_img* itself (``ref_image_void`` true) or
    on ``ref_stack``, then applied to *stack_img*.

    Returns
    -------
    (float32 registered stack, transformation matrices)

    Raises
    ------
    ValueError
        If ``ref_image_void`` is false but no ``ref_stack`` is given.
    """
    sr = StackReg(transformation)
    if ref_image_void:
        tmats_ = sr.register_stack(stack_img, reference=reference)
    else:
        if ref_stack is None:
            raise ValueError("ref_stack must be provided when ref_image_void is False")
        tmats_ = sr.register_stack(ref_stack, reference=reference)
    # The previous unused `sr.transform_stack(ref_stack)` crashed whenever
    # ref_stack was None (the default path) and has been removed.
    out_stk = sr.transform_stack(stack_img, tmats=tmats_)
    return np.float32(out_stk), tmats_
def reg(fix, fix_enface, mov):
    """Translation-register *mov* onto *fix* using an en-face image plus per-frame passes.

    Step 1: the en-face image of *mov* (from ``segment_nfl``) is registered to
    *fix_enface*. Only the row-1 translation of that matrix is applied, to
    every frame of *mov* rotated with ``np.rot90(axes=(0, 1))``.
    Step 2: the result is rotated back and each frame is registered to the
    matching frame of *fix*. Frames without a counterpart in *fix* keep an
    all-zero matrix.
    Per the author, volumes are (x, z, y) with equal x and y sizes.
    """
    registrar = StackReg(TRANSLATION := StackReg.TRANSLATION) if False else StackReg(StackReg.TRANSLATION)