def file_saving_child_process(
    data_buffers,
    buffer_shape,
    input_queue,
    output_queue,
    commands,
    ):
    """Worker loop that optionally saves shared-memory data buffers as TIFs.

    Runs until a ``None`` "poison pill" arrives on `input_queue`; the pill
    is forwarded to `output_queue` and the loop exits.

    Every other item ("permission slip") taken from `input_queue` is a
    dict whose 'which_buffer' key indexes into `data_buffers`. When the
    slip also carries a 'file_info' dict, the first prod(buffer_shape)
    uint16 values of that buffer are viewed with shape `buffer_shape` and
    handed, together with the 'file_info' keyword arguments, to
    np_tif.array_to_tif(). Each processed slip is then put on
    `output_queue`.

    `commands` is polled on every pass; the only command acted on is
    'set_buffer_shape', whose args dict supplies the new 'shape'. The new
    shape is sent back over `commands` as an acknowledgement.

    Raises ImportError (after logging where to get it) if np_tif.py
    cannot be imported.
    """
    try:
        import np_tif
    except ImportError:
        info("Failed to import np_tif.py; go get it from github:")
        info("https://github.com/AndrewGYork/tools/blob/master/np_tif.py")
        raise
    buffer_size = np.prod(buffer_shape)
    while True:
        # Service a pending control message before looking for work.
        if commands.poll():
            cmd, args = commands.recv()
            info("Command received:" + cmd)
            if cmd == 'set_buffer_shape':
                buffer_shape = args['shape']
                buffer_size = np.prod(buffer_shape)
                commands.send(buffer_shape)
                continue
        try:
            slip = input_queue.get_nowait()
        except Q.Empty:
            sleep(0.001)  # Probably doesn't sleep for 1 ms :(
            continue
        if slip is None:
            # Poison pill: pass it downstream, then stop.
            output_queue.put(slip)
            break
        which_buffer = slip['which_buffer']
        info("start buffer %i" % (which_buffer))
        started = clock()
        if 'file_info' in slip:
            # Only slips carrying 'file_info' are saved; its value is a
            # dict of keyword arguments for np_tif.array_to_tif() (for
            # example the file name).
            info("saving buffer %i" % (which_buffer))
            file_info = slip['file_info']
            shared = data_buffers[which_buffer]
            with shared.get_lock():
                # `frames` is a view into the shared buffer, so the lock is
                # held until the write has finished.
                frames = np.frombuffer(shared.get_obj(), dtype=np.uint16)
                frames = frames[:buffer_size].reshape(buffer_shape)
                np_tif.array_to_tif(frames, **file_info)
        info("end buffer %i, elapsed time %06f"
             % (which_buffer, clock() - started))
        output_queue.put(slip)
    return None
def _calibrate_laser(self):
    """Load the modulator calibration, measuring it first if necessary.

    If 'shaker_calibration.tif' does not exist, a calibration run is made:
    for each snap the mirror voltage ramps from 0 to 2 V while the
    modulator is held at a constant voltage during (roughly) the middle of
    the ramp. That voltage steps linearly from 0 to 0.6 V across all
    illuminations. The per-snap 'calibration<N>.tif' files are read back
    and deleted, reduced to one brightness value per illumination, and
    saved with the voltages to 'shaker_calibration.tif'.

    In every case the calibration file is then loaded and
    self.laser_calibration is set to a dict with keys 'modulator_voltage'
    and 'illumination_brightness' (the latter Gaussian-smoothed with
    sigma=0.5).
    """
    calibration_file = 'shaker_calibration.tif'
    if not os.path.exists(calibration_file):
        frames_per_snap = self.idp.buffer_shape[0]
        # Mirror: a sweep three times longer than the measured window, so
        # the measurement can sit in the (hopefully constant-speed) middle.
        measurement_pixels = 80000
        mirror_voltage = np.linspace(0, 2, 3*measurement_pixels)
        measurement_start = (mirror_voltage.shape[0] - measurement_pixels) // 2
        # Modulator: one constant voltage per illumination, rising linearly
        # over the series; round to a whole number of snaps (at least 1).
        desired_num_illuminations = 45
        num_snaps = max(
            1, int(np.round(desired_num_illuminations / frames_per_snap)))
        num_illuminations = num_snaps * frames_per_snap
        modulator_voltage = np.linspace(0, 0.6, num_illuminations)
        illuminations = np.outer(
            modulator_voltage, np.ones(measurement_pixels)
            ).reshape(num_snaps, frames_per_snap, measurement_pixels)
        for snap in range(num_snaps):
            self.snap_mirror_motion(
                mirror_voltage,
                measurement_start_pixel=measurement_start,
                measurement_pixels=measurement_pixels,
                filename='calibration%i.tif' % snap,
                illumination=illuminations[snap, :, :])
        # Read every snap back in, deleting each temporary file once loaded.
        snaps = []
        for snap in range(num_snaps):
            snap_filename = 'calibration%i.tif' % snap
            snaps.append(
                np_tif.tif_to_array(snap_filename).astype(np.float32))
            os.remove(snap_filename)
        data = np.concatenate(snaps, axis=0)
        # Lazy but effective: keep pixels whose median-filtered
        # variance/mean ratio exceeds 40% of its maximum (presumably the
        # illuminated ones), and sum them for each illumination.
        variation = filters.median_filter(
            data.std(axis=0)**2 / data.mean(axis=0), size=3)
        lit = variation > 0.4 * variation.max()
        brightness = (data * lit[np.newaxis]).sum(axis=-1).sum(axis=-1)
        brightness -= brightness[0]
        np_tif.array_to_tif(
            np.array([modulator_voltage, brightness]), calibration_file)
    saved = np_tif.tif_to_array(calibration_file).astype(np.float64)
    self.laser_calibration = {
        'modulator_voltage': saved[0, 0, :],
        'illumination_brightness': filters.gaussian_filter(
            saved[0, 1, :], sigma=0.5),
        }
    return None
def _save_snapped_image_as_tif(self, filename, verbose=True):
    """Write the most recently snapped image to `filename` as a TIF.

    Not meant to be called directly; it should happen as a side effect of
    calling the 'snap' method.

    Raises UserWarning if np_tif could not be imported, and AssertionError
    if `filename` does not end in '.tif'.
    """
    if np_tif is None:
        raise UserWarning(
            "If you want to save as TIF, get np_tif.py here:\n "
            "https://github.com/AndrewGYork/tools/blob/master/np_tif.py\n"
            "We failed to import np_tif.py.\n"
            "This means we can't save camera images as TIFs.")
    assert filename.endswith('.tif')
    image = self._snapped_image_as_numpy_array()
    if verbose:
        print(" Saving a %s %s image as '%s'..."
              % (image.shape, image.dtype, filename), end='')
    np_tif.array_to_tif(image, filename)
    if verbose:
        print(" done")
    return None
def record_iteration(self, save_tifs=True):
    """Record the current iteration and a copy of the current estimate.

    Appends self.num_iterations to self.saved_iterations and a copy of
    self.estimate to self.estimate_history. If `save_tifs` is true, also
    writes two files named with self.output_prefix:
    'estimate_history.tif' (the concatenated, squeezed history) and
    'estimate_FT_error_history.tif' (log(1 + |FFT|) of the history minus
    self.true_object, transformed and fftshifted over axes 1 and 2).
    """
    self.saved_iterations.append(self.num_iterations)
    self.estimate_history.append(self.estimate.copy())
    if not save_tifs:
        return None
    history = np.squeeze(np.concatenate(self.estimate_history, axis=0))
    np_tif.array_to_tif(history, self.output_prefix + 'estimate_history.tif')
    error = history - self.true_object
    if error.ndim == 2:
        # A single 2D error image is treated as a one-frame stack.
        error = error[np.newaxis, :, :]
    spatial_axes = (1, 2)
    error_ft = np.fft.fftshift(
        np.fft.fftn(error, axes=spatial_axes), axes=spatial_axes)
    np_tif.array_to_tif(
        np.log(1 + np.abs(error_ft)),
        self.output_prefix + 'estimate_FT_error_history.tif')
    return None
def record_data(self):
    """Save whichever simulation arrays exist on `self` as TIF files.

    Each optional attribute is written to self.output_prefix + suffix.
    'psfs', 'noiseless_measurement' and 'noisy_measurement' are sequences
    that get concatenated along axis 0 and squeezed first; 'true_object'
    is written as-is. Missing attributes are skipped silently.
    """
    # (attribute name, file suffix, concatenate-and-squeeze first?)
    outputs = (
        ('psfs', 'psfs.tif', True),
        ('true_object', 'object.tif', False),
        ('noiseless_measurement', 'noiseless_measurement.tif', True),
        ('noisy_measurement', 'noisy_measurement.tif', True),
        )
    for attribute, suffix, stacked in outputs:
        if not hasattr(self, attribute):
            continue
        array = getattr(self, attribute)
        if stacked:
            array = np.squeeze(np.concatenate(array, axis=0))
        np_tif.array_to_tif(array, self.output_prefix + suffix)
    return None
def stack_registration(
    s,
    align_to_this_slice=0,
    refinement='spike_interpolation',
    register_in_place=True,
    fourier_cutoff_radius=None,
    background_subtraction=None,
    debug=False,
    ):
    """Calculate shifts which would register the slices of a
    three-dimensional stack `s`, and optionally register the stack in-place.

    Axis 0 is the "z-axis", axis 1 is the "up-down" (Y) axis, and axis 2
    is the "left-right" (X) axis. For each XY slice, we calculate the
    shift in the XY plane which would line that slice up with the slice
    specified by `align_to_this_slice`. If `align_to_this_slice` is a
    number, it indicates which slice of `s` to use as the reference
    slice. If `align_to_this_slice` is a numpy array, it is used as the
    reference slice, and must be the same shape as a 2D slice of `s`
    (after squeezing out length-1 axes).

    `refinement` is one of `integer`, `spike_interpolation`, or
    `phase_fitting`, in order of increasing precision/slowness. I don't
    yet have any evidence that my implementation of phase fitting gives
    any improvement over (faster, simpler) spike interpolation, so caveat
    emptor.

    `register_in_place`: If `True`, modify the input stack `s` by shifting
    its slices to line up with the reference slice.

    `fourier_cutoff_radius`: Ignore the Fourier phases of spatial
    frequencies higher than this cutoff, since they're probably lousy due
    to aliasing and noise anyway. If `None`, attempt to estimate a
    reasonable cutoff.

    `background_subtraction`: One of None, 'mean', 'min', or 'edge_mean'.
    Image registration is sensitive to edge effects. To combat this, we
    multiply the image by a real-space mask which goes to zero at the
    edges. For dim images on large DC backgrounds, the registration can
    end up mistaking this mask for an important image feature, distorting
    the registration. Sometimes it's helpful to subtract a background from
    the image before registration, to reduce this effect. 'mean' and 'min'
    subtract the mean and minimum of the stack `s`, respectively, and
    'edge_mean' subtracts the mean of the edge pixels. Use None (the
    default) for no background subtraction.

    Returns an array with one (Y, X) shift per slice of `s`. Before any
    phase fitting, each shift is wrapped to lie near zero (within half
    the slice size).
    """
    assert len(s.shape) == 3
    if np.ndim(align_to_this_slice) == 0:
        ## A scalar is the index of the reference slice within `s`.
        assert align_to_this_slice in range(s.shape[0])
        align_to_this_slice = s[align_to_this_slice, :, :]
    else:
        ## Anything array-like is the reference slice itself. (Checked
        ## explicitly rather than by catching the ValueError from an
        ## `assert`, which disappears under `python -O`.)
        align_to_this_slice = np.squeeze(align_to_this_slice)
    assert align_to_this_slice.shape == s.shape[-2:]
    assert refinement in ('integer', 'spike_interpolation', 'phase_fitting')
    if refinement == 'phase_fitting' and minimize is None:
        raise UserWarning("Failed to import scipy minimize; no phase fitting.")
    assert register_in_place in (True, False)
    ## What background should we subtract from each slice of the stack?
    ## Strings are compared with '==', not 'is': identity only matches
    ## when the caller's string happens to be interned, and otherwise
    ## left `bg` unassigned.
    assert background_subtraction in (None, 'mean', 'min', 'edge_mean')
    if background_subtraction is None:
        bg = 0
    elif background_subtraction == 'min':
        bg = s.min()
    elif background_subtraction == 'mean':
        bg = s.mean()
    elif background_subtraction == 'edge_mean':
        bg = np.mean((s[:, 0, :].mean(), s[:, -1, :].mean(),
                      s[:, :, 0].mean(), s[:, :, -1].mean()))
    if fourier_cutoff_radius is None:
        fourier_cutoff_radius = estimate_fourier_cutoff_radius(s, bg, debug)
    assert (0 < fourier_cutoff_radius <= 0.5)
    assert debug in (True, False)
    if debug and np_tif is None:
        raise UserWarning("Failed to import np_tif; no debug mode.")

    ## Multiply each slice of the stack by an XY mask that goes to zero
    ## at the edges, to prevent periodic boundary artifacts when we
    ## Fourier transform.
    mask_ud = np.sin(np.linspace(0, np.pi, s.shape[1])).reshape(s.shape[1], 1)
    mask_lr = np.sin(np.linspace(0, np.pi, s.shape[2])).reshape(1, s.shape[2])
    masked_reference_slice = (align_to_this_slice - bg) * mask_ud * mask_lr

    ## We'll base our registration on the phase of the low spatial
    ## frequencies of the cross-power spectrum. We'll need the complex
    ## conjugate of the Fourier transform of the masked reference slice,
    ## and a mask in the Fourier domain to pick out the low spatial
    ## frequencies:
    ref_slice_ft_conj = np.conj(np.fft.rfftn(masked_reference_slice))
    k_ud = np.fft.fftfreq(s.shape[1]).reshape(ref_slice_ft_conj.shape[0], 1)
    k_lr = np.fft.rfftfreq(s.shape[2]).reshape(1, ref_slice_ft_conj.shape[1])
    fourier_mask = (k_ud**2 + k_lr**2) < (fourier_cutoff_radius)**2

    def minimize_me(loc, cross_power_spectrum):
        ## Phase-fitting objective: total disagreement, over the unmasked
        ## frequencies, between the cross-power spectrum expected for
        ## shift `loc` and the measured one.
        disagreement = np.abs(
            expected_cross_power_spectrum(loc, k_ud, k_lr) -
            cross_power_spectrum)[fourier_mask].sum()
        if debug:
            print(" Shift:", loc, "Disagreement:", disagreement)
        return disagreement

    ## Now we'll loop over each slice of the stack, calculate our
    ## registration shifts, and optionally apply the shifts to the
    ## original stack.
    registration_shifts = []
    if debug:
        ## Save some intermediate data to help with debugging. The
        ## real-space buffers are float64 so masked slices aren't
        ## truncated when `s` has an integer dtype.
        masked_stack = np.zeros(s.shape, dtype=np.float64)
        masked_stack_ft = np.zeros((s.shape[0], ) + ref_slice_ft_conj.shape,
                                   dtype=np.complex128)
        masked_stack_ft_vs_ref = np.zeros_like(masked_stack_ft)
        cross_power_spectra = np.zeros_like(masked_stack_ft)
        spikes = np.zeros(s.shape, dtype=np.float64)
    for which_slice in range(s.shape[0]):
        if debug:
            print("Calculating registration for slice", which_slice)
        ## Compute the cross-power spectrum of our slice, and mask out
        ## the high spatial frequencies.
        current_slice = (s[which_slice, :, :] - bg) * mask_ud * mask_lr
        current_slice_ft = np.fft.rfftn(current_slice)
        cross_power_spectrum = current_slice_ft * ref_slice_ft_conj
        cross_power_spectrum = (fourier_mask *
                                cross_power_spectrum /
                                np.abs(cross_power_spectrum))
        ## Inverse transform to get a 'spike' in real space. The
        ## location of this spike gives the desired registration shift.
        ## Start by locating the spike to the nearest integer:
        spike = np.fft.irfftn(cross_power_spectrum, s=current_slice.shape)
        loc = np.array(np.unravel_index(np.argmax(spike), spike.shape))
        if refinement in ('spike_interpolation', 'phase_fitting'):
            ## Use (very simple) three-point polynomial interpolation to
            ## refine the location of the peak of the spike:
            neighbors = np.array([-1, 0, 1])
            ud_vals = spike[(loc[0] + neighbors) % spike.shape[0], loc[1]]
            lr_vals = spike[loc[0], (loc[1] + neighbors) % spike.shape[1]]
            lr_fit = np.poly1d(np.polyfit(neighbors, lr_vals, deg=2))
            ud_fit = np.poly1d(np.polyfit(neighbors, ud_vals, deg=2))
            lr_max_shift = -lr_fit[1] / (2 * lr_fit[2])
            ud_max_shift = -ud_fit[1] / (2 * ud_fit[2])
            loc = loc + (ud_max_shift, lr_max_shift)
        ## Convert our shift into a signed number near zero:
        loc = ((np.array(spike.shape) // 2 + loc) % np.array(spike.shape)
               - np.array(spike.shape) // 2)
        if refinement == 'phase_fitting':
            if debug:
                print("Phase fitting slice", which_slice, "...")
            ## (Attempt to) further refine our registration shift by
            ## fitting Fourier phases. I'm not sure this does any good,
            ## perhaps my implementation is lousy?
            loc = minimize(minimize_me,
                           x0=loc,
                           args=(cross_power_spectrum, ),
                           method='Nelder-Mead').x
        registration_shifts.append(loc)
        if debug:
            ## Save some intermediate data to help with debugging
            masked_stack[which_slice, :, :] = current_slice
            masked_stack_ft[which_slice, :, :] = (
                np.fft.fftshift(current_slice_ft, axes=0))
            masked_stack_ft_vs_ref[which_slice, :, :] = (
                np.fft.fftshift(current_slice_ft * ref_slice_ft_conj, axes=0))
            cross_power_spectra[which_slice, :, :] = (
                np.fft.fftshift(cross_power_spectrum, axes=0))
            spikes[which_slice, :, :] = np.fft.fftshift(spike)
    if register_in_place:
        ## Modify the input stack in-place so it's registered.
        if refinement == 'integer':
            registration_type = 'nearest_integer'
        else:
            registration_type = 'fourier_interpolation'
        apply_registration_shifts(
            s, registration_shifts, registration_type=registration_type)
    if debug:
        np_tif.array_to_tif(masked_stack, 'DEBUG_masked_stack.tif')
        np_tif.array_to_tif(np.log(np.abs(masked_stack_ft)),
                            'DEBUG_masked_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(np.angle(masked_stack_ft),
                            'DEBUG_masked_stack_FT_phases.tif')
        np_tif.array_to_tif(np.angle(masked_stack_ft_vs_ref),
                            'DEBUG_masked_stack_FT_phase_vs_ref.tif')
        np_tif.array_to_tif(np.angle(cross_power_spectra),
                            'DEBUG_cross_power_spectral_phases.tif')
        np_tif.array_to_tif(spikes, 'DEBUG_spikes.tif')
        if register_in_place:
            np_tif.array_to_tif(s, 'DEBUG_registered_stack.tif')
    return np.array(registration_shifts)
np.fft.rfft(stack[0, :, :]).shape) k_ud, k_lr = np.fft.fftfreq(stack.shape[1]), np.fft.rfftfreq( stack.shape[2]) k_ud, k_lr = k_ud.reshape(k_ud.size, 1), k_lr.reshape(1, k_lr.size) for s, (y, x) in enumerate(shifts): top = max(0, y) lef = max(0, x) bot = min(obj.shape[-2], obj.shape[-2] + y) rig = min(obj.shape[-1], obj.shape[-1] + x) shifted_obj.fill(0) shifted_obj[0, top:bot, lef:rig] = obj[0, top - y:bot - y, lef - x:rig - x] stack[s, :, :] = np.random.poisson( brightness * bucket(gaussian_filter(shifted_obj, blur), bucket_size)[0, crop:-crop, crop:-crop]) expected_phases[s, :, :] = np.angle( np.fft.fftshift(expected_cross_power_spectrum( (y / bucket_size[1], x / bucket_size[2]), k_ud, k_lr), axes=0)) np_tif.array_to_tif(expected_phases, 'DEBUG_expected_phase_vs_ref.tif') np_tif.array_to_tif(stack, 'DEBUG_stack.tif') print(" Done.") print("Registering test stack...") calculated_shifts = stack_registration(stack, refinement='spike_interpolation', debug=True) print(" Done.") for s, cs in zip(shifts, calculated_shifts): print('%0.2f (%i)' % (cs[0] * bucket_size[1], s[0]), '%0.2f (%i)' % (cs[1] * bucket_size[2], s[1]))
def stack_registration(s, align_to_this_slice=0,
                       refinement='spike_interpolation',
                       register_in_place=True,
                       registered_stack_is_masked=False,
                       fourier_mask_magnitude=0.15,
                       debug=False):
    """Calculate shifts which would register the slices of a
    three-dimensional stack 's', and optionally register the stack in-place.

    Axis 0 is the "z-axis", axis 1 is the "up-down" (Y) axis, and axis 2
    is the "left-right" (X) axis. For each XY slice, we calculate the
    shift in the XY plane which would line that slice up with the slice
    specified by 'align_to_this_slice' (an index into axis 0).

    'refinement' is one of 'integer', 'spike_interpolation', or
    'phase_fitting', in order of increasing precision/slowness. I don't
    yet have any evidence that my implementation of phase fitting gives
    any improvement over (faster, simpler) spike interpolation, so caveat
    emptor.

    'register_in_place': If 'True', modify the input stack 's' by
    shifting its slices to line up with the reference slice.

    'registered_stack_is_masked': We mask each slice of the stack so that
    it goes to zero at the edges, which reduces Fourier artifacts and
    improves registration accuracy. If we're also modifying the input
    stack, we can save one Fourier transform per iteration if we're
    willing to substitute the 'masked' version of each slice for its
    original value.

    'fourier_mask_magnitude': Ignore the Fourier phases of spatial
    frequencies above this cutoff, since they're probably lousy due to
    aliasing and noise anyway.

    Returns a list with one (Y, X) shift array per slice of 's'.
    """
    assert len(s.shape) == 3
    assert align_to_this_slice in range(s.shape[0])
    assert refinement in ('integer', 'spike_interpolation', 'phase_fitting')
    if refinement == 'phase_fitting' and minimize is None:
        raise UserWarning("Failed to import scipy minimize; no phase fitting.")
    assert register_in_place in (True, False)
    assert registered_stack_is_masked in (True, False)
    assert 0 < fourier_mask_magnitude < 0.5
    assert debug in (True, False)
    if debug and np_tif is None:
        raise UserWarning("Failed to import np_tif; no debug mode.")
    ## Multiply each slice of the stack by an XY mask that goes to zero
    ## at the edges, to prevent periodic boundary artifacts when we
    ## Fourier transform.
    mask_ud = np.sin(np.linspace(0, np.pi, s.shape[1])).reshape(s.shape[1], 1)
    mask_lr = np.sin(np.linspace(0, np.pi, s.shape[2])).reshape(1, s.shape[2])
    masked_reference_slice = s[align_to_this_slice, :, :] * mask_ud * mask_lr
    ## We'll base our registration on the phase of the low spatial
    ## frequencies of the cross-power spectrum. We'll need the complex
    ## conjugate of the Fourier transform of the masked reference slice,
    ## and a mask in the Fourier domain to pick out the low spatial
    ## frequencies:
    ref_slice_ft_conj = np.conj(np.fft.rfftn(masked_reference_slice))
    k_ud = np.fft.fftfreq(s.shape[1]).reshape(ref_slice_ft_conj.shape[0], 1)
    k_lr = np.fft.rfftfreq(s.shape[2]).reshape(1, ref_slice_ft_conj.shape[1])
    fourier_mask = (k_ud**2 + k_lr**2) < (fourier_mask_magnitude)**2

    ## We can also use these Fourier frequencies to define a convenience
    ## function that gives the expected spectral phase associated with
    ## an arbitrary subpixel shift:
    def expected_cross_power_spectrum(shift):
        shift_ud, shift_lr = shift
        return np.exp(-2j * np.pi * (k_ud * shift_ud + k_lr * shift_lr))

    ## Now we'll loop over each slice of the stack, calculate our
    ## registration shifts, and optionally apply the shifts to the
    ## original stack.
    registration_shifts = []
    if debug:
        ## Save some intermediate data to help with debugging. The
        ## real-space buffers are float64: with zeros_like(s), an integer
        ## stack truncated the masked slices and the (fractional) spike
        ## values.
        masked_stack = np.zeros(s.shape, dtype=np.float64)
        masked_stack_ft = np.zeros(
            (s.shape[0], ) + ref_slice_ft_conj.shape, dtype=np.complex128)
        cross_power_spectra = np.zeros(
            (s.shape[0], ) + ref_slice_ft_conj.shape, dtype=np.complex128)
        spikes = np.zeros(s.shape, dtype=np.float64)
    for which_slice in range(s.shape[0]):
        if debug:
            print("Slice", which_slice)
        if which_slice == align_to_this_slice and not debug:
            ## The reference slice needs no shift (skipped in debug mode so
            ## its intermediate data still gets recorded).
            registration_shifts.append(np.array((0, 0)))
            if register_in_place and registered_stack_is_masked:
                s[which_slice, :, :] = masked_reference_slice
            continue
        ## Compute the cross-power spectrum of our slice, and mask out
        ## the high spatial frequencies.
        current_slice = s[which_slice, :, :] * mask_ud * mask_lr
        current_slice_ft = np.fft.rfftn(current_slice)
        cross_power_spectrum = current_slice_ft * ref_slice_ft_conj
        cross_power_spectrum = (fourier_mask *
                                cross_power_spectrum /
                                np.abs(cross_power_spectrum))
        ## Inverse transform to get a 'spike' in real space. The
        ## location of this spike gives the desired registration shift.
        ## Start by locating the spike to the nearest integer:
        spike = np.fft.irfftn(cross_power_spectrum, s=current_slice.shape)
        loc = np.array(np.unravel_index(np.argmax(spike), spike.shape))
        if refinement in ('spike_interpolation', 'phase_fitting'):
            ## Use (very simple) three-point polynomial interpolation to
            ## refine the location of the peak of the spike:
            neighbors = np.array([-1, 0, 1])
            ud_vals = spike[(loc[0] + neighbors) % spike.shape[0], loc[1]]
            lr_vals = spike[loc[0], (loc[1] + neighbors) % spike.shape[1]]
            lr_fit = np.poly1d(np.polyfit(neighbors, lr_vals, deg=2))
            ud_fit = np.poly1d(np.polyfit(neighbors, ud_vals, deg=2))
            lr_max_shift = -lr_fit[1] / (2 * lr_fit[2])
            ud_max_shift = -ud_fit[1] / (2 * ud_fit[2])
            loc = loc + (ud_max_shift, lr_max_shift)
        ## Convert our shift into a signed number near zero:
        loc = ((np.array(spike.shape) // 2 + loc) % np.array(spike.shape)
               - np.array(spike.shape) // 2)
        if refinement == 'phase_fitting':
            if debug:
                print("Phase fitting slice", which_slice, "...")
            ## (Attempt to) further refine our registration shift by
            ## fitting Fourier phases. I'm not sure this does any good,
            ## perhaps my implementation is lousy?
            def minimize_me(loc, cross_power_spectrum):
                disagreement = np.abs(
                    expected_cross_power_spectrum(loc) -
                    cross_power_spectrum)[fourier_mask].sum()
                if debug:
                    print(" Shift:", loc, "Disagreement:", disagreement)
                return disagreement
            loc = minimize(minimize_me,
                           x0=loc,
                           args=(cross_power_spectrum, ),
                           method='Nelder-Mead').x
        registration_shifts.append(loc)
        if register_in_place:
            ## Modify the input stack in-place so it's registered.
            phase_correction = expected_cross_power_spectrum(loc)
            if registered_stack_is_masked:
                ## If we're willing to tolerate a "masked" result, we
                ## can save one Fourier transform:
                s[which_slice, :, :] = np.fft.irfftn(
                    current_slice_ft / phase_correction,
                    s=current_slice.shape).real
            else:
                ## Slower, but probably the right way to do it:
                shift_me = s[which_slice, :, :]
                if not shift_me.dtype == np.float64:
                    shift_me = shift_me.astype(np.float64)
                s[which_slice, :, :] = np.fft.irfftn(
                    np.fft.rfftn(shift_me) / phase_correction,
                    s=current_slice.shape).real
        if debug:
            ## Save some intermediate data to help with debugging
            masked_stack[which_slice, :, :] = current_slice
            masked_stack_ft[which_slice, :, :] = (
                np.fft.fftshift(current_slice_ft, axes=0))
            cross_power_spectra[which_slice, :, :] = (
                np.fft.fftshift(cross_power_spectrum * fourier_mask, axes=0))
            spikes[which_slice, :, :] = np.fft.fftshift(spike)
    if debug:
        np_tif.array_to_tif(masked_stack, 'DEBUG_masked_stack.tif')
        np_tif.array_to_tif(np.log(np.abs(masked_stack_ft)),
                            'DEBUG_masked_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(np.angle(masked_stack_ft),
                            'DEBUG_masked_stack_FT_phases.tif')
        np_tif.array_to_tif(np.angle(cross_power_spectra),
                            'DEBUG_cross_power_spectral_phases.tif')
        np_tif.array_to_tif(spikes, 'DEBUG_spikes.tif')
        if register_in_place:
            np_tif.array_to_tif(s, 'DEBUG_registered_stack.tif')
    return registration_shifts
for z_v in z_voltages: filename = ( 'fluorescence' + ## 'phase_angle_' + ang + '_green' + gr_pow + '_red' + rd_pow + z_v + '_up.tif') assert os.path.exists(filename) print("Loading", filename) data = np_tif.tif_to_array(filename).astype(np.float64) assert data.shape == (3*1, 128, 380) data = data.reshape(1, 3, 128, 380) # Stack to hyperstack data = data.mean(axis=0, keepdims=True) # Sum over reps data_list.append(data) print("Done loading.") data = np.concatenate(data_list) ## print('data shape is',data.shape) ## mean_subtracted = data - data.mean(axis=-3, keepdims=True) tif_shape = (data.shape[0] * data.shape[1], data.shape[2], data.shape[3]) print("Saving...") np_tif.array_to_tif(data.reshape(tif_shape),('dataset_green_all_powers_up.tif')) ## np_tif.array_to_tif(mean_subtracted.reshape(tif_shape), ## 'dataset_green'+gr_pow+'_mean_subtracted.tif') ## np_tif.array_to_tif(data_controlled,('data_controlled_green'+gr_pow+'.tif')) print("Done saving.")
def main():
    """Average STE darkfield stacks over repetitions, one TIF per green power.

    For each green power and each of 32 angles, loads
    'STE_darkfield_<angle>_green<power>_red_300mW_many_delays.tif'
    (1000 repetitions x 5 delays of 128x380 frames). It then crops
    `less_rows` rows from the top and bottom, drops the hand-listed bad
    repetitions, and rescales each frame so its background level (from
    get_bg_level) matches that of the mean frame. Finally it averages over
    repetitions. The result for each green power is saved as
    'dataset_green<power>.tif', holding 32*5 frames of 122x380.
    """
    num_angles = 32
    num_reps_original = 1000
    num_delays = 5
    image_h = 128
    image_w = 380
    less_rows = 3  # rows cropped from both top and bottom (over-exposed)
    cropped_h = image_h - 2 * less_rows

    # Inclusive [first, last] repetition ranges to discard, keyed by angle:
    # dust particles crossing the field of view, or extreme red power
    # fluctuations. Ranges are listed in ascending order.
    gr_on_remove = {
        1: [[432, 435], [694, 719], [860, 865]],
        5: [[54, 57], [505, 508], [901, 906]],
        10: [[214, 218], [421, 428]],
        14: [[356, 360], [391, 394], [711, 713], [774, 802]],
        18: [[208, 210], [661, 667], [989, 992]],
        21: [[181, 187]],
        22: [[63, 70], [328, 333], [440, 451], [544, 557], [897, 902],
             [935, 964]],
        24: [[287, 306], [922, 924]],
        25: [[69, 73], [639, 675], [880, 898]],
        26: [[667, 677]],
        27: [[9, 16], [557, 560], [664, 669]],
        29: [[219, 221], [366, 369], [452, 458], [871, 875]],
        }
    gr_off_remove = {
        1: [[706, 708]],
        5: [[219, 222], [505, 510], [553, 557]],
        7: [[158, 165], [213, 220], [310, 316], [493, 497], [950, 961]],
        12: [[173, 176], [432, 434], [914, 922]],
        13: [[494, 527]],
        14: [[983, 987]],
        15: [[451, 458], [698, 715], [873, 883]],
        16: [[171, 178]],
        17: [[100, 104], [323, 327]],
        21: [[51, 56], [293, 295], [385, 390], [858, 864]],
        22: [[106, 109], [279, 285], [565, 580], [829, 834], [904, 924]],
        }
    # (green power file-name suffix, repetitions to discard for it)
    green_settings = [
        ('_0mW', gr_off_remove),
        ('_1500mW', gr_on_remove),
        ]
    rd_pow = '_300mW'

    data_mean = np.zeros((
        len(green_settings), num_angles, num_delays, cropped_h, image_w))
    for gr_pow_num, (gr_pow, remove_range_dict) in enumerate(green_settings):
        for ang in range(num_angles):
            filename = 'STE_darkfield_%i_green%s_red%s_many_delays.tif' % (
                ang, gr_pow, rd_pow)
            assert os.path.exists(filename)
            print("Loading", filename)
            raw = np_tif.tif_to_array(filename).astype(np.float64)
            assert raw.shape == (
                num_delays * num_reps_original, image_h, image_w)
            # Stack to hyperstack, then crop away the over-exposed rows.
            data = raw.reshape(
                num_reps_original, num_delays, image_h, image_w
                )[:, :, less_rows:image_h - less_rows, :]
            if ang in remove_range_dict:
                # Delete from the last range backwards so the indices of
                # earlier ranges are still valid.
                for first, last in reversed(remove_range_dict[ang]):
                    data = np.delete(data, np.arange(first, last + 1), 0)
            print(ang, data.shape)
            # Correct laser intensity fluctuations (roughly 1% or less):
            # scale every frame so its background level matches that of
            # the repetition-and-delay-averaged frame.
            avg_laser_brightness = get_bg_level(data.mean(axis=(0, 1)))
            local_laser_brightness = get_bg_level(data)
            scale = (avg_laser_brightness / local_laser_brightness).reshape(
                data.shape[0], data.shape[1], 1, 1)
            data_mean[gr_pow_num, ang, ...] = (data * scale).mean(axis=0)
        print("Saving...")
        tif_shape = (num_angles * num_delays, cropped_h, image_w)
        np_tif.array_to_tif(
            data_mean[gr_pow_num, :, :, :, :].reshape(tif_shape),
            'dataset_green' + gr_pow + '.tif')
        print("Done saving.")
    return None
densities from simulated 3D SIM data. To save computation time, we'll ignore the y-dimension, and simulate x-z data. This script outputs a bunch of TIF files on disk. Use ImageJ to view them. """ # Define a 2D x-z test object print("Constructing test object") n_z, n_x = 60, 60 true_density = np.zeros((n_z, n_x)) true_density[n_z // 2 + 5, n_x // 2 + 1] = 1 true_density[n_z // 2 - 5, n_x // 2 - 1] = 1 true_density[n_z // 2, ::4] = 1 true_density[n_z // 2 - 10, ::5] = 1 np_tif.array_to_tif(true_density, '1_true_density.tif') # Define an x-z emission PSF print("Constructing emission PSF") na_limit = 0.25 * np.pi k_magnitude = 0.15 k_z = np.fft.fftfreq(n_z).reshape(n_z, 1) k_x = np.fft.fftfreq(n_x).reshape(1, n_x) k_abs = np.sqrt(k_x**2 + k_z**2) with np.errstate(divide='ignore', invalid='ignore'): # Ugly divide-by-zero code k_theta = np.nan_to_num(np.arccos(k_z / k_abs)) psf_field_ft = np.zeros((n_z, n_x), dtype=np.complex128) psf_field_ft[(k_theta < na_limit) & # Limited NA (np.abs(k_abs - k_magnitude) < 0.01) # Monochromatic ] = 1
## If no hyperstack file then process original data else: data_list = num_f*[None] for fn in range(num_f): input_filename_list[fn] = (input_filename %((fn-int(0.5*(num_f-1)))*mz_step)) print('Found input files, loading...', end='') data_list[fn] = np_tif.tif_to_array(input_filename_list[fn]) print('done') print('Creating np data array...', end='') data = np.asarray(data_list) data = data.reshape((num_f, num_slices) + data.shape[-2:]) print('done') print("Saving result...", end='') np_tif.array_to_tif( data.reshape(num_f*num_slices, data.shape[-2], data.shape[-1]), hyperstack_filename, slices=num_slices, channels=1, frames=num_f) print('done') print('tif shape (Microsope z, RR z, y, x) =', data.shape) ## Add white scale bar to all images for t in range(data.shape[0]): for z in range(data.shape[1]): image = data[t, z, :, :] image[50:60, 1800:1985] = 5000 ## Choose parameters for video current_frame = -1 xmargin = 0.15 ymargin = 0.15
def stack_rotational_registration(
    s,
    align_to_this_slice=0,
    refinement='spike_interpolation',
    register_in_place=True,
    fourier_cutoff_radius=None,
    fail_180_test='fix_but_print_warning',
    debug=False,
    ):
    """Calculate rotations which would rotationally register the slices
    of a three-dimensional stack `s`, and optionally rotate the stack
    in-place.

    Axis 0 is the "z-axis", axis 1 is the "up-down" (Y) axis, and axis 2
    is the "left-right" (X) axis. For each XY slice, we calculate the
    rotation in the XY plane which would line that slice up with the
    slice specified by `align_to_this_slice`.

    If `align_to_this_slice` is a number, it indicates which slice of `s`
    to use as the reference slice. If `align_to_this_slice` is a numpy
    array, it is used as the reference slice, and must be the same shape
    as a 2D slice of `s`.

    `refinement` is pretty much just `spike_interpolation`, until I get
    around to adding other refinement techniques.

    `register_in_place`: If `True`, modify the input stack `s` by
    rotating its slices to line up with the reference slice. The
    rotation is applied to the full slices, even though the analysis
    below works on a center-cropped square view of the stack.

    `fourier_cutoff_radius`: Ignore the Fourier amplitudes of spatial
    frequencies higher than this cutoff, since they're probably lousy due
    to aliasing and noise anyway. If `None`, attempt to estimate a
    reasonable cutoff.

    `fail_180_test`: One of 'fix_but_print_warning' (the default),
    'fix_silently', 'ignore_silently', or 'raise_exception'. The
    algorithm employed here has a fundamental ambiguity: due to
    symmetries in the Fourier domain under 180 degree rotations, the
    rotational registration won't return rotations bigger than +/- 90
    degrees. We currently have a lousy correlation-based check to detect
    if slice(s) in your stack seem to need bigger rotations to align with
    the reference slice. What would you like to happen when this warning
    triggers?

    `debug`: If `True`, print per-slice progress and write several
    intermediate arrays to 'DEBUG_*.tif' files via np_tif.

    Returns a 1D numpy array with one rotation (in degrees) per slice.
    """
    # TODO: take advantage of periodic/smooth decomposition
    assert len(s.shape) == 3
    try:
        assert align_to_this_slice in range(s.shape[0])
        align_to_this_slice = s[align_to_this_slice, :, :]
    except ValueError:
        ## `array in range(...)` raises ValueError (ambiguous truth
        ## value), so we end up here when we were handed an array.
        align_to_this_slice = np.squeeze(align_to_this_slice)
    assert align_to_this_slice.shape == s.shape[-2:]
    ## Create a square-cropped view 'c' of the input stack 's'. We're
    ## going to use a circular mask anyway, and this saves us from
    ## having to worry about nonuniform k_x, k_y sampling in Fourier
    ## space:
    delta = s.shape[1] - s.shape[2]
    y_slice, x_slice = slice(s.shape[1]), slice(s.shape[2])  # Default: no crop
    if delta > 0:  # Crop Y
        y_slice = slice(delta // 2, delta // 2 - delta)
    elif delta < 0:  # Crop X (mirror of the Y crop, using the positive excess)
        excess = -delta
        x_slice = slice(excess // 2, excess // 2 - excess)
    c = s[:, y_slice, x_slice]
    align_to_this_slice = align_to_this_slice[y_slice, x_slice]
    assert c.shape[1] == c.shape[2]  # Take this line out in a few months!
    assert refinement in ('spike_interpolation', )
    assert register_in_place in (True, False)
    if fourier_cutoff_radius is None:
        fourier_cutoff_radius = estimate_fourier_cutoff_radius(c, debug=debug)
    assert (0 < fourier_cutoff_radius <= 0.5)
    assert fail_180_test in ('fix_but_print_warning',
                             'fix_silently',
                             'ignore_silently',
                             'raise_exception')
    assert debug in (True, False)
    if debug and np_tif is None:
        raise UserWarning("Failed to import np_tif; no debug mode.")
    if map_coordinates is None:
        raise UserWarning("Failed to import scipy map_coordinates;" +
                          " no stack_rotational_registration.")
    ## We'll multiply each slice of the stack by a circular mask to
    ## prevent edge-effect artifacts when we Fourier transform:
    mask_ud = np.arange(-c.shape[1] / 2, c.shape[1] / 2).reshape(c.shape[1], 1)
    mask_lr = np.arange(-c.shape[2] / 2, c.shape[2] / 2).reshape(1, c.shape[2])
    mask_r_sq = mask_ud**2 + mask_lr**2
    max_r_sq = (c.shape[1] / 2)**2
    mask = (mask_r_sq <= max_r_sq) * np.cos(
        (np.pi / 2) * (mask_r_sq / max_r_sq))
    del mask_ud, mask_lr, mask_r_sq
    masked_reference_slice = align_to_this_slice * mask
    small_ref = bucket(masked_reference_slice, (4, 4))
    ## We'll base our rotational registration on the logarithms of the
    ## amplitudes of the low spatial frequencies of the reference and
    ## target slices. We'll need the amplitudes of the Fourier transform
    ## of the masked reference slice:
    ref_slice_ft_log_amp = np.log(
        np.abs(np.fft.fftshift(np.fft.rfftn(masked_reference_slice), axes=0)))
    ## Transform the reference slice log amplitudes to polar
    ## coordinates. Note that we avoid some potential subtleties here
    ## because we've cropped to a square field of view (which was the
    ## right thing to do anyway):
    n_y, n_x = ref_slice_ft_log_amp.shape
    k_r = np.arange(1, 2 * fourier_cutoff_radius * n_x)
    ## np.linspace requires an integer sample count:
    k_theta = np.linspace(-np.pi / 2, np.pi / 2,
                          int(np.ceil(2 * fourier_cutoff_radius * n_y)))
    k_theta_delta_degrees = (k_theta[1] - k_theta[0]) * 180 / np.pi
    k_r, k_theta = k_r.reshape(len(k_r), 1), k_theta.reshape(1, len(k_theta))
    k_y = k_r * np.sin(k_theta) + n_y // 2
    k_x = k_r * np.cos(k_theta)
    del k_r, k_theta, n_y, n_x
    polar_ref = map_coordinates(ref_slice_ft_log_amp, (k_y, k_x))
    polar_ref_ft_conj = np.conj(np.fft.rfft(polar_ref, axis=1))
    n_y, n_x = polar_ref_ft_conj.shape
    y, x = np.arange(n_y).reshape(n_y, 1), np.arange(n_x).reshape(1, n_x)
    polar_fourier_mask = (y > x)  # Triangular half-space of good FT phases
    del n_y, n_x, y, x

    def _parabola_vertex(y_m, y_0, y_p):
        """Offset of the vertex of the parabola through (-1, y_m),
        (0, y_0), (1, y_p); 0 when the three points are collinear."""
        curvature = y_m + y_p - 2 * y_0
        if curvature == 0:
            return 0.0
        return (y_m - y_p) / (2 * curvature)

    ## Now we'll loop over each slice of the stack, calculate our
    ## registration rotations, and optionally apply the rotations to the
    ## original stack.
    registration_rotations_degrees = []
    if debug:
        ## Save some intermediate data to help with debugging
        n_z = c.shape[0]
        masked_stack = np.zeros_like(c)
        masked_stack_ft_log_amp = np.zeros((n_z, ) + ref_slice_ft_log_amp.shape)
        polar_stack = np.zeros((n_z, ) + polar_ref.shape)
        polar_stack_ft = np.zeros((n_z, ) + polar_ref_ft_conj.shape,
                                  dtype=np.complex128)
        cross_power_spectra = np.zeros_like(polar_stack_ft)
        spikes = np.zeros_like(polar_stack)
    for which_slice in range(c.shape[0]):
        if debug:
            print("Calculating rotational registration for slice", which_slice)
        ## Compute the polar transform of the log of the amplitude
        ## spectrum of our masked slice:
        current_slice = c[which_slice, :, :] * mask
        current_slice_ft_log_amp = np.log(
            np.abs(np.fft.fftshift(np.fft.rfftn(current_slice), axes=0)))
        polar_slice = map_coordinates(current_slice_ft_log_amp, (k_y, k_x))
        ## Register the polar transform of the current slice against the
        ## polar transform of the reference slice, using a similar
        ## algorithm to the one in mr_stacky:
        polar_slice_ft = np.fft.rfft(polar_slice, axis=1)
        cross_power_spectrum = polar_slice_ft * polar_ref_ft_conj
        norm = np.abs(cross_power_spectrum)
        norm[norm == 0] = 1  # Avoid 0/0 -> NaN, as in stack_registration
        cross_power_spectrum = (polar_fourier_mask * cross_power_spectrum /
                                norm)
        ## Inverse transform to get a 'spike' in each row of fixed k_r;
        ## the location of this spike gives the desired rotation in
        ## theta pixels, which we can convert back to an angle. Start by
        ## locating the spike to the nearest integer:
        spike = np.fft.irfft(cross_power_spectrum, axis=1,
                             n=polar_slice.shape[1])
        spike_1d = spike.sum(axis=0)
        loc = np.argmax(spike_1d)
        if refinement == 'spike_interpolation':
            ## Use (very simple) three-point polynomial interpolation to
            ## refine the location of the peak of the spike:
            neighbors = np.array([-1, 0, 1])
            y_m, y_0, y_p = spike_1d[(loc + neighbors) % spike.shape[1]]
            loc = loc + _parabola_vertex(y_m, y_0, y_p)
        ## Convert our shift into a signed number near zero:
        loc = ((loc + spike.shape[1] // 2) % spike.shape[1] -
               spike.shape[1] // 2)
        ## Convert this shift in "theta pixels" back to a rotation in degrees:
        angle = loc * k_theta_delta_degrees
        ## There's a fundamental 180-degree ambiguity in the rotation
        ## determined by this algorithm. We need to check if adding 180
        ## degrees explains our data better. This is a half-assed fast
        ## check that just downsamples the bejesus out of the relevant
        ## slices and looks at cross-correlation.
        ## TODO: FIX THIS FOR REAL! ...or does it matter? Testing, for now.
        small_current_slice = bucket(current_slice, (4, 4))
        small_cur_000 = rotate(small_current_slice, angle=angle, reshape=False)
        small_cur_180 = rotate(small_current_slice, angle=angle + 180,
                               reshape=False)
        if (small_cur_180 * small_ref).sum() > (small_cur_000 * small_ref).sum():
            ## Compare strings with ==, not `is`: identity of equal
            ## strings is an interning accident, not a guarantee.
            if fail_180_test == 'fix_but_print_warning':
                angle += 180
                print(" **Warning: potentially ambiguous rotation detected**")
                print(" Inspect the registration for rotations off by 180" +
                      " degrees, at slice %i" % (which_slice))
            elif fail_180_test == 'fix_silently':
                angle += 180
            elif fail_180_test == 'ignore_silently':
                pass
            elif fail_180_test == 'raise_exception':
                raise UserWarning(
                    "Potentially ambiguous rotation detected.\n" +
                    "One of your slices needed more than 90 degrees rotation")
        ## Pencils down.
        registration_rotations_degrees.append(angle)
        if debug:
            ## Save some intermediate data to help with debugging
            masked_stack[which_slice, :, :] = current_slice
            masked_stack_ft_log_amp[which_slice, :, :] = (
                current_slice_ft_log_amp)
            polar_stack[which_slice, :, :] = polar_slice
            polar_stack_ft[which_slice, :, :] = polar_slice_ft
            cross_power_spectra[which_slice, :, :] = cross_power_spectrum
            spikes[which_slice, :, :] = np.fft.fftshift(spike, axes=1)
    if register_in_place:
        ## Rotate the slices of the input stack in-place so it's
        ## rotationally registered:
        for which_slice in range(s.shape[0]):
            s[which_slice, :, :] = rotate(
                s[which_slice, :, :],
                angle=registration_rotations_degrees[which_slice],
                reshape=False)
    if debug:
        np_tif.array_to_tif(masked_stack, 'DEBUG_masked_stack.tif')
        np_tif.array_to_tif(masked_stack_ft_log_amp,
                            'DEBUG_masked_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(polar_stack, 'DEBUG_polar_stack.tif')
        np_tif.array_to_tif(np.angle(polar_stack_ft),
                            'DEBUG_polar_stack_FT_phases.tif')
        np_tif.array_to_tif(np.log(np.abs(polar_stack_ft)),
                            'DEBUG_polar_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(np.angle(cross_power_spectra),
                            'DEBUG_cross_power_spectral_phases.tif')
        np_tif.array_to_tif(spikes, 'DEBUG_spikes.tif')
        if register_in_place:
            np_tif.array_to_tif(s, 'DEBUG_registered_stack.tif')
    return np.array(registration_rotations_degrees)
def stack_registration(
    s,
    align_to_this_slice=0,
    refinement='spike_interpolation',
    register_in_place=True,
    fourier_cutoff_radius=None,
    debug=False,
    ):
    """Calculate shifts which would register the slices of a
    three-dimensional stack `s`, and optionally register the stack
    in-place.

    Axis 0 is the "z-axis", axis 1 is the "up-down" (Y) axis, and axis 2
    is the "left-right" (X) axis. For each XY slice, we calculate the
    shift in the XY plane which would line that slice up with the slice
    specified by `align_to_this_slice`.

    If `align_to_this_slice` is an integer, it indicates which slice of
    `s` to use as the reference slice. If `align_to_this_slice` is a
    numpy array, it is used as the reference slice, and must be the same
    shape as a 2D slice of `s`.

    `refinement` is one of `integer`, (fast registration to the nearest
    pixel) or `spike_interpolation` (slower but hopefully more accurate
    sub-pixel registration).

    `register_in_place`: If `True`, modify the input stack `s` by
    shifting its slices to line up with the reference slice.

    `fourier_cutoff_radius`: Ignore the Fourier phases of spatial
    frequencies higher than this cutoff, since they're probably lousy due
    to aliasing and noise anyway. If `None`, attempt to estimate a
    reasonable cutoff.

    `debug`: (Attempt to) output several TIF files that are often useful
    when stack_registration isn't working. This needs to import
    np_tif.py; you can get a copy of np_tif.py from
    https://github.com/AndrewGYork/tools/blob/master/np_tif.py
    Probably the best way to understand what these files are, and how to
    interpret them, is to read the code of the rest of this function.

    Returns a numpy array of shape (n_slices, 2) holding the (Y, X)
    shift found for each slice. When the reference is given as a slice
    index (and `debug` is off), that slice is not processed and gets a
    shift of (0, 0). A slice with no usable signal (flat correlation
    spike) gets a zero sub-pixel correction rather than NaN.
    """
    assert len(s.shape) == 3
    try:
        assert align_to_this_slice in range(s.shape[0])
        skip_aligning_this_slice = align_to_this_slice
        align_to_this_slice = s[align_to_this_slice, :, :]
    except ValueError:
        ## `array in range(...)` raises ValueError (ambiguous truth
        ## value), so we end up here when we were handed an array.
        skip_aligning_this_slice = None
        align_to_this_slice = np.squeeze(align_to_this_slice)
    assert align_to_this_slice.shape == s.shape[-2:]
    assert refinement in ('integer', 'spike_interpolation')
    assert register_in_place in (True, False)
    if fourier_cutoff_radius is not None:
        assert (0 < fourier_cutoff_radius <= 0.5)
    assert debug in (True, False)
    if debug and np_tif is None:
        raise UserWarning("Failed to import np_tif; no debug mode.")
    ## We'll base our registration on the phase of the low spatial
    ## frequencies of the cross-power spectrum. We'll need the complex
    ## conjugate of the Fourier transform of the periodic version of the
    ## reference slice, and a mask in the Fourier domain to pick out the
    ## low spatial frequencies:
    ref_slice_rfft = np.fft.rfftn(align_to_this_slice)
    if fourier_cutoff_radius is None:  # Attempt to estimate a sensible FT radius
        fourier_cutoff_radius = estimate_fourier_cutoff_radius(
            ref_slice_rfft, debug=debug, input_is_rfftd=True, shape=s.shape[1:])
    ref_slice_ft_conj = np.conj(ref_slice_rfft -
                                _smooth_rfft2(align_to_this_slice))
    del ref_slice_rfft
    k_ud = np.fft.fftfreq(s.shape[1]).reshape(ref_slice_ft_conj.shape[0], 1)
    k_lr = np.fft.rfftfreq(s.shape[2]).reshape(1, ref_slice_ft_conj.shape[1])
    fourier_mask = (k_ud**2 + k_lr**2) < (fourier_cutoff_radius)**2

    def _parabola_vertex(y_m, y_0, y_p):
        """Offset of the vertex of the parabola through (-1, y_m),
        (0, y_0), (1, y_p); 0 when the three points are collinear
        (e.g. a flat spike from a blank slice), instead of 0/0 -> NaN."""
        curvature = y_m + y_p - 2 * y_0
        if curvature == 0:
            return 0.0
        return (y_m - y_p) / (2 * curvature)

    ## Now we'll loop over each slice of the stack, calculate our
    ## registration shifts, and optionally apply the shifts to the
    ## original stack.
    registration_shifts = []
    if debug:
        ## Save some intermediate data to help with debugging
        stack_ft = np.zeros((s.shape[0], ) + ref_slice_ft_conj.shape,
                            dtype=np.complex128)
        periodic_stack = np.zeros_like(s)
        periodic_stack_ft = np.zeros((s.shape[0], ) + ref_slice_ft_conj.shape,
                                     dtype=np.complex128)
        periodic_stack_ft_vs_ref = np.zeros_like(periodic_stack_ft)
        cross_power_spectra = np.zeros_like(periodic_stack_ft)
        spikes = np.zeros(s.shape, dtype=np.float64)
    for which_slice in range(s.shape[0]):
        if which_slice == skip_aligning_this_slice and not debug:
            registration_shifts.append(np.zeros(2))
            continue
        if debug:
            print("Calculating registration for slice", which_slice)
        ## Compute the cross-power spectrum of the periodic component of
        ## our slice with the periodic component of the reference slice,
        ## and mask out the high spatial frequencies.
        current_slice_ft = np.fft.rfftn(s[which_slice, :, :])
        current_slice_ft_periodic = (current_slice_ft -
                                     _smooth_rfft2(s[which_slice, :, :]))
        cross_power_spectrum = current_slice_ft_periodic * ref_slice_ft_conj
        norm = np.abs(cross_power_spectrum)
        norm[norm == 0] = 1
        cross_power_spectrum = (fourier_mask * cross_power_spectrum / norm)
        ## Inverse transform to get a 'spike' in real space. The
        ## location of this spike gives the desired registration shift.
        ## Start by locating the spike to the nearest integer:
        spike = np.fft.irfftn(cross_power_spectrum, s=s.shape[1:])
        loc = np.array(np.unravel_index(np.argmax(spike), spike.shape))
        if refinement == 'spike_interpolation':
            ## Use (very simple) three-point polynomial interpolation to
            ## refine the location of the peak of the spike:
            neighbors = np.array([-1, 0, 1])
            ud_vals = spike[(loc[0] + neighbors) % spike.shape[0], loc[1]]
            lr_vals = spike[loc[0], (loc[1] + neighbors) % spike.shape[1]]
            ud_max_shift = _parabola_vertex(*ud_vals)
            lr_max_shift = _parabola_vertex(*lr_vals)
            loc = loc + (ud_max_shift, lr_max_shift)
        ## Convert our shift into a signed number near zero:
        loc = ((np.array(spike.shape) // 2 + loc) % np.array(spike.shape)
               - np.array(spike.shape) // 2)
        registration_shifts.append(loc)
        if register_in_place:
            ## Modify the input stack in-place so it's registered.
            apply_registration_shifts(
                s[which_slice:which_slice + 1, :, :],
                [loc],
                registration_type=('nearest_integer'
                                   if refinement == 'integer'
                                   else 'fourier_interpolation'),
                s_rfft=np.expand_dims(current_slice_ft, axis=0))
        if debug:
            ## Save some intermediate data to help with debugging
            stack_ft[which_slice, :, :] = (
                np.fft.fftshift(current_slice_ft, axes=0))
            periodic_stack[which_slice, :, :] = np.fft.irfftn(
                current_slice_ft_periodic, s=s.shape[1:]).real
            periodic_stack_ft[which_slice, :, :] = (
                np.fft.fftshift(current_slice_ft_periodic, axes=0))
            periodic_stack_ft_vs_ref[which_slice, :, :] = (np.fft.fftshift(
                current_slice_ft_periodic * ref_slice_ft_conj, axes=0))
            cross_power_spectra[which_slice, :, :] = (
                np.fft.fftshift(cross_power_spectrum, axes=0))
            spikes[which_slice, :, :] = np.fft.fftshift(spike)
    if debug:
        np_tif.array_to_tif(periodic_stack, 'DEBUG_1_periodic_stack.tif')
        np_tif.array_to_tif(np.log(1 + np.abs(stack_ft)),
                            'DEBUG_2_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(np.angle(stack_ft),
                            'DEBUG_3_stack_FT_phases.tif')
        np_tif.array_to_tif(np.log(1 + np.abs(periodic_stack_ft)),
                            'DEBUG_4_periodic_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(np.angle(periodic_stack_ft),
                            'DEBUG_5_periodic_stack_FT_phases.tif')
        np_tif.array_to_tif(np.angle(periodic_stack_ft_vs_ref),
                            'DEBUG_6_periodic_stack_FT_phase_vs_ref.tif')
        np_tif.array_to_tif(np.angle(cross_power_spectra),
                            'DEBUG_7_cross_power_spectral_phases.tif')
        np_tif.array_to_tif(spikes, 'DEBUG_8_spikes.tif')
        if register_in_place:
            np_tif.array_to_tif(s, 'DEBUG_9_registered_stack.tif')
    return np.array(registration_shifts)
def generate_psfs(
    shape, #Desired pixel dimensions of the psfs
    excitation_brightness, #Peak brightness in saturation units
    depletion_brightness, #Peak brightness in saturation units
    blur_sigma,
    psf_type='point',
    output_dir=None,
    verbose=True,
    ):
    """Simulate excitation, depletion and STED PSFs.

    A utility function used by psf_report().

    `shape`: Desired pixel dimensions of the 3D psfs. The point (or
    line) source is placed in slice 0 of axis 0. For psf_type='line' the
    rescan binning reshapes to a leading axis of length 1, so shape[0]
    must be 1 there.
    `excitation_brightness`, `depletion_brightness`: Peak brightness of
    the excitation and depletion patterns, in saturation units.
    `blur_sigma`: Gaussian sigma (pixels) used to build the excitation
    and depletion patterns. For line-STED it is also used as the
    emission PSF sigma.
    `psf_type`: 'point' or 'line'. Any other value raises ValueError.
    `output_dir`: If not None, save the intermediate PSFs as TIF files
    in this directory (created if it doesn't exist; requires np_tif).
    `verbose`: Print progress of the line-rescan simulation.

    Returns a dict of numpy arrays with keys 'excitation', 'depletion',
    'excitation_fraction', 'depletion_fraction', 'sted' and
    'descan_sted'. For psf_type='line' it also has 'rescan_sted'.
    """
    if psf_type not in ('point', 'line'):
        ## Previously an unknown type fell through every branch and the
        ## function silently returned None.
        raise ValueError(
            "psf_type must be 'point' or 'line', not %r" % (psf_type,))
    # Calculate gaussian point and line excitation patterns. Pixel
    # values encode fluence per pulse, in saturation units.
    if psf_type == 'point':
        excitation_psf_point = np.zeros(shape)
        excitation_psf_point[0, shape[1] // 2, shape[2] // 2] = 1
        excitation_psf_point = gaussian_filter(excitation_psf_point,
                                               sigma=blur_sigma)
        excitation_psf_point *= (excitation_brightness /
                                 excitation_psf_point.max())
    elif psf_type == 'line':
        excitation_psf_line = np.zeros(shape)
        excitation_psf_line[0, :, shape[2] // 2] = 1
        excitation_psf_line = gaussian_filter(excitation_psf_line,
                                              sigma=(0, 0, blur_sigma))
        excitation_psf_line *= (excitation_brightness /
                                excitation_psf_line.max())
    # Calculate difference-of-gaussian point and line depletion
    # patterns. Pixel values encode fluence per pulse, in saturation
    # units. The "outer" gaussian is the "inner" one blurred again, so
    # the normalized difference is zero at the center.
    if psf_type == 'point':
        depletion_psf_inner = np.zeros(shape)
        depletion_psf_inner[0, shape[1] // 2, shape[2] // 2] = 1
        depletion_psf_inner = gaussian_filter(depletion_psf_inner,
                                              sigma=blur_sigma)
        depletion_psf_outer = gaussian_filter(depletion_psf_inner,
                                              sigma=blur_sigma)
        depletion_psf_point = (
            (depletion_psf_outer / depletion_psf_outer.max()) -
            (depletion_psf_inner / depletion_psf_inner.max()))
        depletion_psf_point *= depletion_brightness / depletion_psf_point.max()
    elif psf_type == 'line':
        depletion_psf_inner = np.zeros(shape)
        depletion_psf_inner[0, :, shape[2] // 2] = 1
        depletion_psf_inner = gaussian_filter(depletion_psf_inner,
                                              sigma=(0, 0, blur_sigma))
        depletion_psf_outer = gaussian_filter(depletion_psf_inner,
                                              sigma=(0, 0, blur_sigma))
        depletion_psf_line = (
            (depletion_psf_outer / depletion_psf_outer.max()) -
            (depletion_psf_inner / depletion_psf_inner.max()))
        depletion_psf_line *= depletion_brightness / depletion_psf_line.max()
    # Calculate "saturated" excitation/depletion patterns. Pixel values
    # encode probability per pulse that a ground-state molecule will
    # become excited (excitation) or an excited molecule will remain
    # excited (depletion).
    half_on_dose = 1
    half_off_dose = 1
    if psf_type == 'point':
        saturated_excitation_psf_point = 1 - 2**(-excitation_psf_point /
                                                 half_on_dose)
        saturated_depletion_psf_point = 2**(-depletion_psf_point /
                                            half_off_dose)
    elif psf_type == 'line':
        saturated_excitation_psf_line = 1 - 2**(-excitation_psf_line /
                                                half_on_dose)
        saturated_depletion_psf_line = 2**(-depletion_psf_line /
                                           half_off_dose)
    # Calculate post-depletion excitation patterns. Pixel values encode
    # probability per pulse that a molecule will become excited, but not
    # be depleted.
    if psf_type == 'point':
        sted_psf_point = (saturated_excitation_psf_point *
                          saturated_depletion_psf_point)
    elif psf_type == 'line':
        sted_psf_line = (saturated_excitation_psf_line *
                         saturated_depletion_psf_line)
    # Calculate the "system" PSF, which can depend on both excitation
    # and emission. For descanned point-STED, the system PSF is the
    # (STED-shrunk) excitation PSF. For rescanned line-STED, the system
    # PSF also involves the emission PSF.
    if psf_type == 'point':
        descanned_point_sted_psf = sted_psf_point  # Simple rename
    elif psf_type == 'line':
        emission_sigma = blur_sigma  # Assume emission PSF same as excitation PSF
        line_sted_sigma, _ = get_width(
            sted_psf_line[0, sted_psf_line.shape[1] // 2, :])
        line_rescan_ratio = (emission_sigma / line_sted_sigma)**2 + 1
        if verbose:
            print(" Ideal line rescan ratio: %0.5f" % (line_rescan_ratio))
        line_rescan_ratio = int(np.round(line_rescan_ratio))
        if verbose:
            print(" Nearest integer:", line_rescan_ratio)
        point_obj = np.zeros(shape)
        point_obj[0, point_obj.shape[1] // 2, point_obj.shape[2] // 2] = 1
        emission_psf = gaussian_filter(point_obj, sigma=emission_sigma)
        rescanned_signal_inst = np.zeros(
            (point_obj.shape[0],
             point_obj.shape[1],
             int(line_rescan_ratio * point_obj.shape[2])))
        rescanned_signal_cumu = rescanned_signal_inst.copy()
        descanned_signal_cumu = np.zeros(shape)
        if verbose:
            print(" Calculating rescan psf...", end='')
        # I could use an analytical shortcut to calculate the rescan PSF
        # (see http://dx.doi.org/10.1364/BOE.4.002644 for details), but
        # for the sake of clarity, I explicitly simulate the rescan
        # imaging process:
        for scan_position in range(point_obj.shape[2]):
            # . Scan the excitation
            scanned_excitation = np.roll(
                sted_psf_line, scan_position - point_obj.shape[2] // 2, axis=2)
            # . Multiply the object by excitation to calculate the "glow":
            glow = point_obj * scanned_excitation
            # . Blur the glow by the emission PSF to calculate the image
            #   on the detector:
            blurred_glow = gaussian_filter(glow, sigma=emission_sigma)
            # . Calculate the contribution to the descanned image (the
            #   kind measured by Curdt or Schubert
            #   http://www.ub.uni-heidelberg.de/archiv/14362
            #   http://www.ub.uni-heidelberg.de/archiv/15986
            descanned_signal = np.roll(
                blurred_glow, point_obj.shape[2] // 2 - scan_position, axis=2)
            descanned_signal_cumu[:, :, scan_position] += descanned_signal.sum(
                axis=2)
            # . Roll the descanned image to the rescan position, to
            #   produce the "instantaneous" rescanned image:
            rescanned_signal_inst.fill(0)
            rescanned_signal_inst[0, :, :point_obj.shape[2]] = descanned_signal
            rescanned_signal_inst = np.roll(
                rescanned_signal_inst,
                scan_position * line_rescan_ratio - point_obj.shape[2] // 2,
                axis=2)
            # . Add the "instantaneous" image to the "cumulative" image.
            rescanned_signal_cumu += rescanned_signal_inst
        if verbose:
            print(" ...done.")
        # . Bin the rescanned psf back to the same dimensions as the object:
        rescanned_line_sted_psf = np.roll(
            rescanned_signal_cumu,  # Roll so center bin is centered on the image
            int(line_rescan_ratio // 2),
            axis=2).reshape(  # Quick and dirty binning
                1,
                rescanned_signal_cumu.shape[1],
                int(rescanned_signal_cumu.shape[2] / line_rescan_ratio),
                int(line_rescan_ratio)).sum(axis=3)
        descanned_line_sted_psf = descanned_signal_cumu  # Simple rename
    if output_dir is not None:
        ## makedirs(exist_ok=True) also creates missing parent directories
        ## and avoids the exists()/mkdir() race.
        os.makedirs(output_dir, exist_ok=True)
        if psf_type == 'point':
            for array, filename in (
                (excitation_psf_point, 'excitation_psf_point.tif'),
                (depletion_psf_point, 'depletion_psf_point.tif'),
                (saturated_excitation_psf_point,
                 'excitation_fraction_psf_point.tif'),
                (saturated_depletion_psf_point,
                 'depletion_fraction_psf_point.tif'),
                (sted_psf_point, 'sted_psf_point.tif')):
                np_tif.array_to_tif(array, os.path.join(output_dir, filename))
        elif psf_type == 'line':
            for array, filename in (
                (excitation_psf_line, 'excitation_psf_line.tif'),
                (depletion_psf_line, 'depletion_psf_line.tif'),
                (saturated_excitation_psf_line,
                 'excitation_fraction_psf_line.tif'),
                (saturated_depletion_psf_line,
                 'depletion_fraction_psf_line.tif'),
                (sted_psf_line, 'sted_psf_line.tif'),
                (emission_psf, 'emission_psf.tif'),
                (rescanned_signal_cumu, 'sted_psf_line_rescan_unscaled.tif'),
                (rescanned_line_sted_psf, 'sted_psf_line_rescan.tif'),
                (descanned_line_sted_psf, 'sted_psf_line_descan.tif')):
                np_tif.array_to_tif(array, os.path.join(output_dir, filename))
    if psf_type == 'point':
        return {
            'excitation': excitation_psf_point,
            'depletion': depletion_psf_point,
            'excitation_fraction': saturated_excitation_psf_point,
            'depletion_fraction': saturated_depletion_psf_point,
            'sted': sted_psf_point,
            'descan_sted': descanned_point_sted_psf
        }
    return {
        'excitation': excitation_psf_line,
        'depletion': depletion_psf_line,
        'excitation_fraction': saturated_excitation_psf_line,
        'depletion_fraction': saturated_depletion_psf_line,
        'sted': sted_psf_line,
        'descan_sted': descanned_line_sted_psf,
        'rescan_sted': rescanned_line_sted_psf
    }
else: print('Loading original file...', end='', sep='') data = np_tif.tif_to_array(input_filename) print('done') data = data.reshape((num_tps, ) + data.shape[-2:]) print('tif shape (t, y, x) =', data.shape) print('Cropping...', end='', sep='') if left_crop or right_crop > 0: data = data[:, :, left_crop:-right_crop] if top_crop or bottom_crop > 0: data = data[:, top_crop:-bottom_crop, :] print('done') print("Saving result...", end='', sep='') np_tif.array_to_tif(data.reshape(num_tps, data.shape[-2], data.shape[-1]), cropped_filename, slices=1, channels=1, frames=num_tps) print('done') print('tif shape (t, y, x) =', data.shape) ## Choose parameters for video current_frame = -1 xmargin = 0.01 ymargin = 0.025 space = 0.175 img_size = 0.5 max_intensity = 12500 wlw = 2 # white line width half amplitude start_tp = 0 # removing uneventful begging stop_tp = 100 # remove uneventful end
# crop data (rectangle around bright fluorescent lobe) data_rep_cropped = data_rep[:, :, :, :, 49:97, 147:195] data_rep_signal = data_rep_cropped.mean(axis=5).mean(axis=4) # get data background data_rep_bg = data_rep[:, :, :, :, 20:30, 20:30] data_rep_bg = data_rep_bg.mean(axis=5).mean(axis=4) # compute repetition average of data after image registration data_avg = data_rep.mean(axis=0) ##representative_image_avg = (data_avg[0,-1,-1,16:128,108:238] - ## data_rep_bg[:,0,-1,-1].mean(axis=0)) rep_image_single_shot = (data_rep[0, -1, -1, -1, 16:128, 108:238] - data_rep_bg[0, -1, -1, -1]) data_avg_tif_shape = (data_avg.shape[0] * data_avg.shape[1] * data_avg.shape[2], data_avg.shape[3], data_avg.shape[4]) point_data_tif_shape = (data_rep_signal.shape[0] * data_rep_signal.shape[1], data_rep_signal.shape[2], data_rep_signal.shape[3]) print("Saving...") np_tif.array_to_tif(data_rep_signal.reshape(point_data_tif_shape), 'data_point_signal.tif') np_tif.array_to_tif(data_rep_bg.reshape(point_data_tif_shape), 'data_point_bg.tif') ##np_tif.array_to_tif( ## representative_image_avg,'representative_image_avg.tif') np_tif.array_to_tif(rep_image_single_shot, 'rep_image_single_shot.tif') print("... done.")
def _scan_filenames(rep_num):
    """Return the raw-data TIF names for repetition `rep_num`.

    Returns a ``(green_on, green_blocked)`` pair of filenames.
    """
    base = 'STE_darkfield_power_delay_scan_' + str(rep_num)
    return base + '.tif', base + '_green_blocked.tif'


def _load_cropped_scan(filename, top, bot):
    """Load a TIF stack as float64, keeping only rows ``top:bot``."""
    return np_tif.tif_to_array(filename).astype(np.float64)[:, top:bot, :]


def _red_power_brightness(scan, scan_shape, frame_shape):
    """Return get_bg_level() of `scan` averaged over green powers and delays.

    `scan` is reshaped to ``scan_shape + frame_shape`` (red power, green
    power, delay, y, x) and averaged over axes 1 and 2 (green power and
    delay) before being passed to get_bg_level().
    """
    per_red_power = scan.reshape(
        scan_shape + frame_shape).mean(axis=1).mean(axis=1)
    return get_bg_level(per_red_power)


def _rescale_to_mean_brightness(data, red_avg_brightness, scan_shape):
    """Rescale each frame of `data` to compensate for red-beam fluctuations.

    `data` is viewed as ``scan_shape + (y, x)`` for get_bg_level(); every
    frame is then multiplied by ``red_avg_brightness / local level``.
    Returns a new array; `data` itself is not modified.
    """
    # Use data's own frame shape, so the with-green and green-blocked
    # stacks are each reshaped by their own dimensions.
    local_laser_brightness = get_bg_level(
        data.reshape(scan_shape + data.shape[-2:]))
    calibration_factor = red_avg_brightness / local_laser_brightness
    num_frames = int(np.prod(scan_shape))
    return data * calibration_factor.reshape(num_frames, 1, 1)


def _add_representative_images(ste_image, darkfield_image, data, crop,
                               num_reps):
    """Accumulate one rep's contribution to the averaged images, in place.

    The last five frames of `data` are the five delays at max red and max
    green power. Frame -3 is zero delay (STE). Frames -1 and -5 are the
    two largest absolute red/green delays (+2.5 us / -2.5 us); they are
    averaged together for the darkfield image.
    """
    rows, cols = crop
    ste_image += data[-3, rows, cols] / num_reps
    darkfield_image += data[-1, rows, cols] / num_reps / 2
    darkfield_image += data[-5, rows, cols] / num_reps / 2


def main():
    """Register, brightness-correct and average STE darkfield power scans.

    Each raw data stack holds a full red and green power scan. Red power
    varies slowest, green power faster, and the green/red pulse delay
    fastest (5 delays; the middle delay is 0 delay). For every
    repetition, one stack with green light and one with green blocked
    are read from the current directory.

    Writes TIFs of:
    - the per-rep power/delay-averaged images, before and after
      registration;
    - the registration shifts;
    - the brightness-corrected, space-averaged signal per (rep, red
      power, green power, delay);
    - averaged representative STE and darkfield images.
    """
    num_reps = 200  # number of power scans taken
    num_red_powers = 7
    num_green_powers = 13
    num_delays = 5
    image_h = 128
    image_w = 380
    less_rows = 3  # top/bottom 3 rows may contain leakage from outside pixels
    top = less_rows
    bot = image_h - less_rows
    scan_shape = (num_red_powers, num_green_powers, num_delays)
    frame_shape = (image_h - less_rows * 2, image_w)

    # Assume no sample motion during a single power scan, so we register
    # the power/delay-averaged image of each rep.
    data_rep = np.zeros((num_reps,) + frame_shape, dtype=np.float64)
    data_rep_bg = np.zeros((num_reps,) + frame_shape, dtype=np.float64)
    # Average red beam brightness for each red power, accumulated over all
    # reps both with and without green light (hence the factor of 2).
    red_avg_brightness = np.zeros(num_red_powers)

    # Populate the hyperstacks from the data.
    for rep_num in range(num_reps):
        filename, filename_bg = _scan_filenames(rep_num)

        print("Loading", filename)
        power_scan = _load_cropped_scan(filename, top, bot)
        red_avg_brightness += _red_power_brightness(
            power_scan, scan_shape, frame_shape) / (2 * num_reps)
        data_rep[rep_num, :, :] = power_scan.mean(axis=0)

        print("Loading", filename_bg)
        power_scan_bg = _load_cropped_scan(filename_bg, top, bot)
        red_avg_brightness += _red_power_brightness(
            power_scan_bg, scan_shape, frame_shape) / (2 * num_reps)
        data_rep_bg[rep_num, :, :] = power_scan_bg.mean(axis=0)

    # Add singleton axes so this broadcasts against a brightness array of
    # shape (num_red_powers, num_green_powers, num_delays).
    red_avg_brightness = red_avg_brightness.reshape(num_red_powers, 1, 1)

    # Pick the image all stacks are aligned to. Copy it, because data_rep
    # is registered in place below and the green-blocked stack must be
    # aligned to the same (unmodified) reference.
    representative_rep_num = 0
    align_slice = data_rep[representative_rep_num, :, :].copy()

    # Save pre-registered average data (all powers for each rep).
    np_tif.array_to_tif(data_rep, 'dataset_not_registered_power_avg.tif')
    np_tif.array_to_tif(
        data_rep_bg, 'dataset_green_blocked_not_registered_power_avg.tif')

    # Compute registration shifts (this also registers data_rep* in place).
    print("Computing registration shifts...")
    shifts = stack_registration(
        data_rep,
        align_to_this_slice=align_slice,
        refinement='integer',
        register_in_place=True,
        background_subtraction='edge_mean')
    print("Computing registration shifts (no green) ...")
    shifts_bg = stack_registration(
        data_rep_bg,
        align_to_this_slice=align_slice,
        refinement='integer',
        register_in_place=True,
        background_subtraction='edge_mean')

    # Save registered average data (all powers for each rep) and shifts.
    np_tif.array_to_tif(data_rep, 'dataset_registered_power_avg.tif')
    np_tif.array_to_tif(
        data_rep_bg, 'dataset_green_blocked_registered_power_avg.tif')
    np_tif.array_to_tif(shifts, 'shifts.tif')
    np_tif.array_to_tif(shifts_bg, 'shifts_bg.tif')

    # Now apply the shifts to the raw data and compute the space-averaged
    # signal and the representative images.
    # Box around the main lobe used for the space-averaged signal:
    rect_top = 44
    rect_bot = 102
    rect_left = 172
    rect_right = 228

    print('Applying shifts to raw data...')
    signal = np.zeros((num_reps,) + scan_shape, dtype=np.float64)
    signal_bg = np.zeros((num_reps,) + scan_shape, dtype=np.float64)

    # Cropping coordinates for the representative images.
    rep_top = 22
    rep_bot = 122
    rep_left = 136
    rep_right = 262
    rep_crop = (slice(rep_top, rep_bot), slice(rep_left, rep_right))
    rep_shape = (rep_bot - rep_top, rep_right - rep_left)
    darkfield_image = np.zeros(rep_shape, dtype=np.float64)
    STE_image = np.zeros(rep_shape, dtype=np.float64)
    darkfield_image_bg = np.zeros(rep_shape, dtype=np.float64)
    STE_image_bg = np.zeros(rep_shape, dtype=np.float64)

    for rep_num in range(num_reps):
        filename, filename_bg = _scan_filenames(rep_num)
        data = _load_cropped_scan(filename, top, bot)
        data_bg = _load_cropped_scan(filename_bg, top, bot)
        print(filename)
        print(filename_bg)

        # One shift per rep, applied to every frame of that rep.
        apply_registration_shifts(
            data,
            registration_shifts=[shifts[rep_num]] * data.shape[0],
            registration_type='nearest_integer',
            edges='sloppy')
        apply_registration_shifts(
            data_bg,
            registration_shifts=[shifts_bg[rep_num]] * data_bg.shape[0],
            registration_type='nearest_integer',
            edges='sloppy')

        # Re-scale images to compensate for red beam brightness fluctuations.
        data = _rescale_to_mean_brightness(
            data, red_avg_brightness, scan_shape)
        data_bg = _rescale_to_mean_brightness(
            data_bg, red_avg_brightness, scan_shape)

        # Spatially average the signal inside the box around the bright lobe.
        signal[rep_num, :, :, :] = data[
            :, rect_top:rect_bot, rect_left:rect_right
            ].mean(axis=2).mean(axis=1).reshape(scan_shape)
        signal_bg[rep_num, :, :, :] = data_bg[
            :, rect_top:rect_bot, rect_left:rect_right
            ].mean(axis=2).mean(axis=1).reshape(scan_shape)

        # Average images at max red / max green power.
        _add_representative_images(
            STE_image, darkfield_image, data, rep_crop, num_reps)
        _add_representative_images(
            STE_image_bg, darkfield_image_bg, data_bg, rep_crop, num_reps)
    print('Done applying shifts')

    signal_tif_shape = (signal.shape[0] * signal.shape[1],
                        signal.shape[2],
                        signal.shape[3])
    print("Saving...")
    np_tif.array_to_tif(signal.reshape(signal_tif_shape),
                        'signal_all_scaled.tif')
    np_tif.array_to_tif(signal_bg.reshape(signal_tif_shape),
                        'signal_green_blocked_all_scaled.tif')
    np_tif.array_to_tif(darkfield_image, 'darkfield_image_avg.tif')
    np_tif.array_to_tif(darkfield_image_bg, 'darkfield_image_bg_avg.tif')
    np_tif.array_to_tif(STE_image, 'STE_image_avg.tif')
    np_tif.array_to_tif(STE_image_bg, 'STE_image_bg_avg.tif')
    print("... done.")
    return None
# Test-stack check of stack_registration (enclosing definition not visible
# here). For each (y, x) in `shifts`:
# - copy `obj` into a zeroed `shifted_obj`, offset by (y, x) with no
#   wrap-around;
# - store a gaussian_filter-blurred (`blur`), bucketed, cropped and
#   Poisson-noised version of it in stack[s];
# - store in expected_phases[s] the phase of expected_cross_power_spectrum()
#   for the bucketed shift (y/bucket_size[1], x/bucket_size[2]).
# Both arrays are written to DEBUG_*.tif files. The stack is then
# registered with spike interpolation (debug=True), and each calculated
# shift, scaled back by bucket_size, is printed beside the true shift.
expected_phases = np.zeros((len(shifts),) + np.fft.rfft(stack[0, :, :]).shape) k_ud, k_lr = np.fft.fftfreq(stack.shape[1]), np.fft.rfftfreq(stack.shape[2]) k_ud, k_lr = k_ud.reshape(k_ud.size, 1), k_lr.reshape(1, k_lr.size) for s, (y, x) in enumerate(shifts): top = max(0, y) lef = max(0, x) bot = min(obj.shape[-2], obj.shape[-2] + y) rig = min(obj.shape[-1], obj.shape[-1] + x) shifted_obj.fill(0) shifted_obj[0, top:bot, lef:rig] = obj[0, top-y:bot-y, lef-x:rig-x] stack[s, :, :] = np.random.poisson( brightness * bucket(gaussian_filter(shifted_obj, blur), bucket_size )[0, crop:-crop, crop:-crop]) expected_phases[s, :, :] = np.angle(np.fft.fftshift( expected_cross_power_spectrum((y/bucket_size[1], x/bucket_size[2]), k_ud, k_lr), axes=0)) np_tif.array_to_tif(expected_phases, 'DEBUG_expected_phase_vs_ref.tif') np_tif.array_to_tif(stack, 'DEBUG_stack.tif') print(" Done.") print("Registering test stack...") calculated_shifts = stack_registration( stack, refinement='spike_interpolation', debug=True) print(" Done.") for s, cs in zip(shifts, calculated_shifts): print('%0.2f (%i)'%(cs[0] * bucket_size[1], s[0]), '%0.2f (%i)'%(cs[1] * bucket_size[2], s[1]))
# (The leading `]` closes a list literal that starts outside this view.)
# For each green power:
# - load every 'STE_phase_angle_<ang>_green<gr_pow>_red<rd_pow>.tif' over
#   `angles` and `red_powers`; each file is asserted to exist and to have
#   shape (3, 128, 380);
# - concatenate them into an (N, 3, 128, 380) array;
# - save it flattened to (N * 3, 128, 380) as
#   'dataset_green<gr_pow>_single_shot.tif'.
] for gr_pow in green_powers: data_list = [] for ang in angles: for rd_pow in red_powers: filename = ( 'STE_' + 'phase_angle_' + ang + '_green' + gr_pow + '_red' + rd_pow + '.tif') assert os.path.exists(filename) print("Loading", filename) data = np_tif.tif_to_array(filename).astype(np.float64) assert data.shape == (3, 128, 380) data = data.reshape(1, data.shape[0], data.shape[1], data.shape[2]) data_list.append(data) print("Done loading.") data = np.concatenate(data_list) print('data shape is',data.shape) tif_shape = (data.shape[0] * data.shape[1], data.shape[2], data.shape[3]) print("Saving...") np_tif.array_to_tif(data.reshape(tif_shape), ('dataset_green'+gr_pow+'_single_shot.tif')) print("Done saving.")
def _parabolic_peak_offset(values):
    """Return the vertex position of the parabola through three samples.

    `values` are samples at positions -1, 0 and +1 around a peak. The
    returned offset is relative to position 0. If the samples have no
    curvature (e.g. all three are equal), the vertex is undefined, so 0.0
    is returned instead of dividing by zero.
    """
    before, center, after = values
    curvature = before - 2 * center + after
    if curvature == 0:
        return 0.0
    return (before - after) / (2 * curvature)


def stack_registration(
    s,
    align_to_this_slice=0,
    refinement='spike_interpolation',
    register_in_place=True,
    fourier_cutoff_radius=None,
    debug=False):
    """Calculate shifts which would register the slices of a
    three-dimensional stack `s`, and optionally register the stack in-place.

    Axis 0 is the "z-axis", axis 1 is the "up-down" (Y) axis, and axis 2
    is the "left-right" (X) axis. For each XY slice, we calculate the shift
    in the XY plane which would line that slice up with the slice
    specified by `align_to_this_slice`. If `align_to_this_slice` is a
    number, it indicates which slice of `s` to use as the reference
    slice. If `align_to_this_slice` is a numpy array, it is used as the
    reference slice, and must be the same shape as a 2D slice of `s`.

    `refinement` is one of `integer`, `spike_interpolation`, or
    `phase_fitting`, in order of increasing precision/slowness. I don't
    yet have any evidence that my implementation of phase fitting gives
    any improvement over (faster, simpler) spike interpolation, so caveat
    emptor.

    `register_in_place`: If `True`, modify the input stack `s` by shifting
    its slices to line up with the reference slice. `integer` refinement
    uses 'nearest_integer' shifting; the other refinements use
    'fourier_interpolation'.

    `fourier_cutoff_radius`: Ignore the Fourier phases of spatial
    frequencies higher than this cutoff, since they're probably lousy due
    to aliasing and noise anyway. If `None`, attempt to estimate a
    reasonable cutoff. Must satisfy 0 < cutoff <= 0.5.

    `debug`: If `True`, print progress and save intermediate arrays as
    DEBUG_*.tif files (requires np_tif).

    Returns a numpy array of shape (s.shape[0], 2). Each row is the
    (up-down, left-right) location of the phase-correlation peak of that
    slice against the reference, wrapped to lie near zero. A slice equal
    to the reference rolled by (dy, dx) yields approximately (dy, dx).
    The dtype is integer for `integer` refinement and float otherwise.
    """
    assert len(s.shape) == 3
    # Decide explicitly whether we got a slice index or a reference image.
    # (Relying on `array in range(...)` raising ValueError breaks under
    # `python -O`, where the assert that triggers it is stripped.)
    if np.ndim(align_to_this_slice) == 0:
        assert align_to_this_slice in range(s.shape[0])
        align_to_this_slice = s[align_to_this_slice, :, :]
    else:
        align_to_this_slice = np.squeeze(align_to_this_slice)
    assert align_to_this_slice.shape == s.shape[-2:]
    assert refinement in ('integer', 'spike_interpolation', 'phase_fitting')
    if refinement == 'phase_fitting' and minimize is None:
        raise UserWarning("Failed to import scipy minimize; no phase fitting.")
    assert register_in_place in (True, False)
    assert debug in (True, False)
    if fourier_cutoff_radius is None:
        fourier_cutoff_radius = estimate_fourier_cutoff_radius(s, debug)
    assert (0 < fourier_cutoff_radius <= 0.5)
    if debug and np_tif is None:
        raise UserWarning("Failed to import np_tif; no debug mode.")
    ## Multiply each slice of the stack by an XY mask that goes to zero
    ## at the edges, to prevent periodic boundary artifacts when we
    ## Fourier transform.
    mask_ud = np.sin(np.linspace(0, np.pi, s.shape[1])).reshape(s.shape[1], 1)
    mask_lr = np.sin(np.linspace(0, np.pi, s.shape[2])).reshape(1, s.shape[2])
    masked_reference_slice = align_to_this_slice * mask_ud * mask_lr
    ## We'll base our registration on the phase of the low spatial
    ## frequencies of the cross-power spectrum. We'll need the complex
    ## conjugate of the Fourier transform of the masked reference slice,
    ## and a mask in the Fourier domain to pick out the low spatial
    ## frequencies:
    ref_slice_ft_conj = np.conj(np.fft.rfftn(masked_reference_slice))
    k_ud = np.fft.fftfreq(s.shape[1]).reshape(ref_slice_ft_conj.shape[0], 1)
    k_lr = np.fft.rfftfreq(s.shape[2]).reshape(1, ref_slice_ft_conj.shape[1])
    fourier_mask = (k_ud**2 + k_lr**2) < (fourier_cutoff_radius)**2
    ## Now we'll loop over each slice of the stack, calculate our
    ## registration shifts, and optionally apply the shifts to the
    ## original stack.
    registration_shifts = []
    if debug:
        ## Save some intermediate data to help with debugging
        masked_stack = np.zeros_like(s)
        masked_stack_ft = np.zeros(
            (s.shape[0],) + ref_slice_ft_conj.shape, dtype=np.complex128)
        masked_stack_ft_vs_ref = np.zeros_like(masked_stack_ft)
        cross_power_spectra = np.zeros_like(masked_stack_ft)
        spikes = np.zeros(s.shape, dtype=np.float64)
    for which_slice in range(s.shape[0]):
        if debug:
            print("Calculating registration for slice", which_slice)
        ## Compute the cross-power spectrum of our slice, and mask out
        ## the high spatial frequencies.
        current_slice = s[which_slice, :, :] * mask_ud * mask_lr
        current_slice_ft = np.fft.rfftn(current_slice)
        cross_power_spectrum = current_slice_ft * ref_slice_ft_conj
        ## Normalize to unit magnitude (phase correlation). Frequencies
        ## with zero magnitude carry no phase information; leave them at
        ## zero instead of producing 0/0 = NaN, which would poison the
        ## whole spike.
        magnitude = np.abs(cross_power_spectrum)
        cross_power_spectrum = fourier_mask * np.divide(
            cross_power_spectrum, magnitude,
            out=np.zeros_like(cross_power_spectrum),
            where=magnitude > 0)
        ## Inverse transform to get a 'spike' in real space. The
        ## location of this spike gives the desired registration shift.
        ## Start by locating the spike to the nearest integer:
        spike = np.fft.irfftn(cross_power_spectrum, s=current_slice.shape)
        loc = np.array(np.unravel_index(np.argmax(spike), spike.shape))
        if refinement in ('spike_interpolation', 'phase_fitting'):
            ## Use (very simple) three-point parabolic interpolation to
            ## refine the location of the peak of the spike:
            neighbors = np.array([-1, 0, 1])
            ud_vals = spike[(loc[0] + neighbors) % spike.shape[0], loc[1]]
            lr_vals = spike[loc[0], (loc[1] + neighbors) % spike.shape[1]]
            loc = loc + (_parabolic_peak_offset(ud_vals),
                         _parabolic_peak_offset(lr_vals))
        ## Convert our shift into a signed number near zero:
        loc = ((np.array(spike.shape)//2 + loc) % np.array(spike.shape)
               -np.array(spike.shape)//2)
        if refinement == 'phase_fitting':
            if debug:
                print("Phase fitting slice", which_slice, "...")
            ## (Attempt to) further refine our registration shift by
            ## fitting Fourier phases. I'm not sure this does any good,
            ## perhaps my implementation is lousy?
            def minimize_me(loc, cross_power_spectrum):
                disagreement = np.abs(
                    expected_cross_power_spectrum(loc, k_ud, k_lr) -
                    cross_power_spectrum
                    )[fourier_mask].sum()
                if debug:
                    print(" Shift:", loc, "Disagreement:", disagreement)
                return disagreement
            loc = minimize(minimize_me,
                           x0=loc,
                           args=(cross_power_spectrum,),
                           method='Nelder-Mead').x
        registration_shifts.append(loc)
        if debug:
            ## Save some intermediate data to help with debugging
            masked_stack[which_slice, :, :] = current_slice
            masked_stack_ft[which_slice, :, :] = (
                np.fft.fftshift(current_slice_ft, axes=0))
            masked_stack_ft_vs_ref[which_slice, :, :] = (
                np.fft.fftshift(current_slice_ft * ref_slice_ft_conj, axes=0))
            cross_power_spectra[which_slice, :, :] = (
                np.fft.fftshift(cross_power_spectrum, axes=0))
            spikes[which_slice, :, :] = np.fft.fftshift(spike)
    if register_in_place:
        ## Modify the input stack in-place so it's registered, using the
        ## shifting method that matches the requested refinement.
        if refinement == 'integer':
            registration_type = 'nearest_integer'
        else:
            registration_type = 'fourier_interpolation'
        apply_registration_shifts(
            s, registration_shifts, registration_type=registration_type)
    if debug:
        np_tif.array_to_tif(masked_stack, 'DEBUG_masked_stack.tif')
        np_tif.array_to_tif(np.log(np.abs(masked_stack_ft)),
                            'DEBUG_masked_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(np.angle(masked_stack_ft),
                            'DEBUG_masked_stack_FT_phases.tif')
        np_tif.array_to_tif(np.angle(masked_stack_ft_vs_ref),
                            'DEBUG_masked_stack_FT_phase_vs_ref.tif')
        np_tif.array_to_tif(np.angle(cross_power_spectra),
                            'DEBUG_cross_power_spectral_phases.tif')
        np_tif.array_to_tif(spikes, 'DEBUG_spikes.tif')
        if register_in_place:
            np_tif.array_to_tif(s, 'DEBUG_registered_stack.tif')
    return np.array(registration_shifts)
# Tail of a registration script (enclosing definition not visible here):
# - applies registration_shifts to data[which_t];
# - crops `border` pixels from each edge of every (y, x) frame;
# - saves the result to `registered_filename` with slices=num_slices,
#   channels=1, frames=num_tps;
# - sets video layout parameters.
# NOTE(review): with border == 0 the slices `top_crop:-bottom_crop` and
# `left_crop:-right_crop` become `0:-0`, i.e. empty -- confirm border > 0.
# NOTE(review): the printed label '(t, z, y, z)' looks like a typo for
# '(t, z, y, x)'.
sr.apply_registration_shifts(data[which_t, :, :, :], registration_shifts) print('done') # Registraction will often leave black borders so it's best to crop # at the end to tidy up. print('Cropping...') left_crop = border right_crop = border top_crop = border bottom_crop = border data = data[:, :, top_crop:-bottom_crop, left_crop:-right_crop] print('done') print("Saving result...", end='') np_tif.array_to_tif(data.reshape(num_tps * num_slices, data.shape[-2], data.shape[-1]), registered_filename, slices=num_slices, channels=1, frames=num_tps) print('done') print('tif shape (t, z, y, z) =', data.shape) ## Choose parameters for video num_z_stack = 4 pause = 4 z_slow_down_factor = 1 z_scale = 10 current_frame = -1 current_tp = -1 xmargin = 0.15 ymargin = 0.725 space = 0.0275
# Tail of a registration self-test (enclosing definition not visible here;
# the leading `[-5, 0], ]` presumably closes the `shifts` list).
# - stack[0] is `obj` bucketed by bucket_size and cropped by `crop` pixels
#   on each side.
# - stack[s + 1] is `obj` offset by shifts[s] (zero-filled, no
#   wrap-around), then bucketed and cropped the same way.
# The stack is saved to DEBUG_stack.tif and registered with spike
# interpolation (debug=True); the calculated shifts are printed. The
# stack is then saved again as DEBUG_stack_registered.tif (registered in
# place if stack_registration's default register_in_place=True applies).
[-5, 0], ] bucket_size = (1, 4, 4) crop = 8 print("Creating shifted stack from test object...") shifted_obj = np.zeros_like(obj) stack = np.zeros( (len(shifts) + 1, obj.shape[1] // bucket_size[1] - 2 * crop, obj.shape[2] // bucket_size[2] - 2 * crop)) stack[0, :, :] = bucket(obj, bucket_size)[0, crop:-crop, crop:-crop] for s, (y, x) in enumerate(shifts): top = max(0, y) lef = max(0, x) bot = min(obj.shape[-2], obj.shape[-2] + y) rig = min(obj.shape[-1], obj.shape[-1] + x) shifted_obj.fill(0) shifted_obj[0, top:bot, lef:rig] = obj[0, top - y:bot - y, lef - x:rig - x] stack[s + 1, :, :] = bucket(shifted_obj, bucket_size)[0, crop:-crop, crop:-crop] np_tif.array_to_tif(stack, 'DEBUG_stack.tif') print(" Done.") print("Registering test stack...") shifts = stack_registration(stack, refinement='spike_interpolation', debug=True) print(" Done.") for s in shifts: print(s) np_tif.array_to_tif(stack, 'DEBUG_stack_registered.tif')