def generate_unit_phase_shifts(shape, float_type=float):
    r"""
        Computes the complex phase shift's angle due to a unit spatial shift.

        This is meant to be a helper function for ``register_mean_offsets``.
        It does this by computing a table of the angle of the phase of a unit
        shift in each dimension (with a factor of :math:`2\pi`).

        This allows arbitrary phase shifts to be made in each dimension by
        multiplying these angles by the size of the shift and adding the
        result to the existing angle to induce the proper phase shift in
        fourier space, which is equivalent to the spatial translation.

        Args:
            shape(tuple of ints):       shape of the data to be shifted (any
                                        sequence of ints is accepted).

            float_type(real type):      phase type (default numpy.float64).
                                        Must be a floating point type that is
                                        at least 32 bits wide.

        Returns:
            (numpy.ndarray):            an array of shape
                                        ``(len(shape),) + shape`` containing
                                        the angle of the complex phase shift
                                        to use for each dimension.

        Examples:
            >>> generate_unit_phase_shifts((2,4))
            array([[[-0.        , -0.        , -0.        , -0.        ],
                    [-3.14159265, -3.14159265, -3.14159265, -3.14159265]],
            <BLANKLINE>
                   [[-0.        , -1.57079633, -3.14159265, -4.71238898],
                    [-0.        , -1.57079633, -3.14159265, -4.71238898]]])
    """

    # Convert to `numpy`-based type if not done already.
    float_type = numpy.dtype(float_type).type

    # Must be of type float.
    assert issubclass(float_type, numpy.floating)
    assert numpy.dtype(float_type).itemsize >= 4

    # Accept any sequence of ints, not only tuples.
    shape = tuple(shape)
    num_dims = len(shape)

    # Get the negative wave vector, -2*pi / N, for each dimension. The
    # operation order (reciprocal, scale, negate) is kept so results stay
    # bit-for-bit the same as before.
    negative_wave_vector = numpy.asarray(shape, dtype=float_type)
    numpy.reciprocal(negative_wave_vector, out=negative_wave_vector)
    negative_wave_vector *= 2 * numpy.pi
    numpy.negative(negative_wave_vector, out=negative_wave_vector)

    # Get the indices for each point in the selected space. The result
    # already has the layout (dimension, *shape); so, no flattening,
    # transposing, or copying is needed afterwards.
    indices = numpy.indices(shape)

    # Determine the phase offset for each point in space by scaling the index
    # along each dimension by that dimension's negative wave vector entry.
    complex_angle_unit_shift = indices * negative_wave_vector.reshape(
        (num_dims,) + (1,) * num_dims
    )

    return complex_angle_unit_shift
def _iter_frame_blocks(num_frames, block_frame_length):
    """
        Iterates over the frame blocks used for block-wise processing.

        Builds the sequence of block boundaries
        ``0, block_frame_length, 2*block_frame_length, ..., num_frames`` and
        hands it to ``iters.lagged_generators_zipped``. Every caller unpacks
        each item into an ``(i, j)`` pair and uses it as the slice ``[i:j]``;
        presumably these are consecutive boundary pairs.

        Args:
            num_frames(int):            total number of frames.
            block_frame_length(int):    number of frames per block.

        Returns:
            (iterable):                 items that unpack to ``(i, j)``.
    """

    return iters.lagged_generators_zipped(
        itertools.chain(
            xrange(0, num_frames, block_frame_length),
            [num_frames]
        )
    )


def register_mean_offsets(frames2reg,
                          max_iters=-1,
                          block_frame_length=-1,
                          include_shift=False,
                          to_truncate=False,
                          float_type=numpy.dtype(float).type):
    """
        This algorithm registers the given image stack against its mean
        projection. This is done by computing translations needed to put each
        frame in alignment. Then the translation is performed and new
        translations are computed. This is repeated until no further
        improvement can be made.

        The code for translations can be found in find_offsets.

        Notes:
            Adapted from code provided by Wenzhi Sun with speed improvements
            provided by Uri Dubin.

        Args:
            frames2reg(numpy.ndarray):           Image stack to register (time
                                                 is the first dimension uses
                                                 C-order tyx or tzyx). An
                                                 ``h5py.Dataset`` is also
                                                 handled (see Returns).

            max_iters(int):                      Number of iterations to allow
                                                 before forcing termination if
                                                 stable point is not found
                                                 yet. Set to -1 if no limit.
                                                 (Default -1)

            block_frame_length(int):             Number of frames to work with
                                                 at a time. By default all.
                                                 (Default -1)

            include_shift(bool):                 Whether to return the shifts
                                                 used, as well. (Default
                                                 False)

            to_truncate(bool):                   Whether to truncate the
                                                 frames to remove all masked
                                                 portions. (Default False)

            float_type(type):                    Type of float to use for
                                                 calculation. (Default
                                                 numpy.float64).

        Returns:
            (numpy.ndarray or tuple or str):     When working purely in memory
                                                 (``frames2reg`` is not an
                                                 ``h5py.Dataset`` and all
                                                 frames are in one block), the
                                                 registered frames: a masked
                                                 array, or a plain truncated
                                                 array if ``to_truncate``. If
                                                 ``include_shift``, a tuple of
                                                 the registered frames and the
                                                 integer shifts per frame.
                                                 Otherwise, the path of a
                                                 ``results.h5`` file that the
                                                 registered frames (and the
                                                 shifts if ``include_shift``)
                                                 were copied into.

        Examples:
            >>> a = numpy.zeros((5, 3, 4)); a[:,0] = 1; a[2,0] = 0; a[2,2] = 1
            >>> a
            array([[[ 1.,  1.,  1.,  1.],
                    [ 0.,  0.,  0.,  0.],
                    [ 0.,  0.,  0.,  0.]],
            <BLANKLINE>
                   [[ 1.,  1.,  1.,  1.],
                    [ 0.,  0.,  0.,  0.],
                    [ 0.,  0.,  0.,  0.]],
            <BLANKLINE>
                   [[ 0.,  0.,  0.,  0.],
                    [ 0.,  0.,  0.,  0.],
                    [ 1.,  1.,  1.,  1.]],
            <BLANKLINE>
                   [[ 1.,  1.,  1.,  1.],
                    [ 0.,  0.,  0.,  0.],
                    [ 0.,  0.,  0.,  0.]],
            <BLANKLINE>
                   [[ 1.,  1.,  1.,  1.],
                    [ 0.,  0.,  0.,  0.],
                    [ 0.,  0.,  0.,  0.]]])

            >>> register_mean_offsets(a, include_shift=True)
            (masked_array(data =
             [[[1.0 1.0 1.0 1.0]
              [0.0 0.0 0.0 0.0]
              [0.0 0.0 0.0 0.0]]
            <BLANKLINE>
             [[1.0 1.0 1.0 1.0]
              [0.0 0.0 0.0 0.0]
              [0.0 0.0 0.0 0.0]]
            <BLANKLINE>
             [[-- -- -- --]
              [0.0 0.0 0.0 0.0]
              [0.0 0.0 0.0 0.0]]
            <BLANKLINE>
             [[1.0 1.0 1.0 1.0]
              [0.0 0.0 0.0 0.0]
              [0.0 0.0 0.0 0.0]]
            <BLANKLINE>
             [[1.0 1.0 1.0 1.0]
              [0.0 0.0 0.0 0.0]
              [0.0 0.0 0.0 0.0]]],
                         mask =
             [[[False False False False]
              [False False False False]
              [False False False False]]
            <BLANKLINE>
             [[False False False False]
              [False False False False]
              [False False False False]]
            <BLANKLINE>
             [[ True  True  True  True]
              [False False False False]
              [False False False False]]
            <BLANKLINE>
             [[False False False False]
              [False False False False]
              [False False False False]]
            <BLANKLINE>
             [[False False False False]
              [False False False False]
              [False False False False]]],
                   fill_value = 0.0)
            , array([[0, 0],
                   [0, 0],
                   [1, 0],
                   [0, 0],
                   [0, 0]]))
    """

    float_type = numpy.dtype(float_type).type

    # Must be of type float and must be at least 32-bit (smallest complex type
    # uses two 32-bit floats).
    assert issubclass(float_type, numpy.floating)
    assert numpy.dtype(float_type).itemsize >= 4

    # Sadly, there is no easier way to map the two types; so, this is it.
    float_complex_mapping = {
        numpy.float32: numpy.complex64,
        numpy.float64: numpy.complex128,
    }
    # The extended precision types are not provided by `numpy` on every
    # platform; so, only reference them where they exist. Otherwise even
    # callers using 32-bit or 64-bit floats would fail here.
    if hasattr(numpy, "float128") and hasattr(numpy, "complex256"):
        float_complex_mapping[numpy.float128] = numpy.complex256

    complex_type = float_complex_mapping[float_type]

    num_frames = len(frames2reg)

    if block_frame_length == -1:
        block_frame_length = num_frames

    # Decide whether intermediates live on disk. For an HDF5 dataset they go
    # into a file next to the dataset's own file. For in-memory data that is
    # processed in more than one block they go into a fresh temporary
    # directory. Otherwise, everything stays in memory.
    tempdir_name = ""
    temporaries_filename = ""
    if isinstance(frames2reg, h5py.Dataset):
        tempdir_name, temporaries_filename = os.path.split(
            os.path.abspath(frames2reg.file.filename)
        )

        temporaries_filename = os.path.splitext(temporaries_filename)[0]
        temporaries_filename += "_".join(
            [
                frames2reg.name.replace("/", "_"),
                "temporaries.h5"
            ]
        )
        temporaries_filename = os.path.join(
            tempdir_name,
            temporaries_filename
        )
    elif block_frame_length != num_frames:
        tempdir_name = tempfile.mkdtemp()
        temporaries_filename = os.path.join(tempdir_name, "temporaries.h5")

    frames2reg_fft = None
    space_shift = None
    this_space_shift = None
    if tempdir_name:
        temporaries_file = h5py.File(temporaries_filename, "w")

        frames2reg_fft = temporaries_file.create_dataset(
            "frames2reg_fft",
            shape=frames2reg.shape,
            dtype=complex_type
        )
        space_shift = temporaries_file.create_dataset(
            "space_shift",
            shape=(num_frames, len(frames2reg.shape) - 1),
            dtype=int
        )
        this_space_shift = temporaries_file.create_dataset(
            "this_space_shift",
            shape=space_shift.shape,
            dtype=space_shift.dtype
        )
    else:
        frames2reg_fft = numpy.empty(frames2reg.shape, dtype=complex_type)
        space_shift = numpy.zeros(
            (num_frames, len(frames2reg.shape) - 1),
            dtype=int
        )
        this_space_shift = numpy.empty_like(space_shift)

    # Transform every frame over its spatial axes (all axes but the first).
    for i, j in _iter_frame_blocks(num_frames, block_frame_length):
        frames2reg_fft[i:j] = fft.fftn(
            frames2reg[i:j], axes=range(1, len(frames2reg.shape))
        )

    template_fft = numpy.empty(frames2reg.shape[1:], dtype=complex_type)

    # Table of phase angles for a unit shift along each spatial dimension.
    unit_space_shift_fft = generate_unit_phase_shifts(
        template_fft.shape, float_type
    )

    # Repeat shift calculation until there is no further adjustment.
    num_iters = 0
    squared_magnitude_delta_space_shift = 1.0
    while squared_magnitude_delta_space_shift != 0.0:
        squared_magnitude_delta_space_shift = 0.0

        # Build the template as the mean of all frames after applying the
        # current shifts as phase factors in fourier space.
        template_fft[:] = 0
        for i, j in _iter_frame_blocks(num_frames, block_frame_length):
            frames2reg_shifted_fft_ij = numpy.exp(
                1j * numpy.tensordot(
                    space_shift[i:j],
                    unit_space_shift_fft,
                    axes=[-1, 0]
                )
            )
            frames2reg_shifted_fft_ij *= frames2reg_fft[i:j]
            template_fft += numpy.sum(frames2reg_shifted_fft_ij, axis=0)
        template_fft /= num_frames

        for i, j in _iter_frame_blocks(num_frames, block_frame_length):
            this_space_shift[i:j] = find_offsets(
                frames2reg_fft[i:j],
                template_fft
            )

        # Remove global shifts.
        this_space_shift_mean = numpy.zeros(
            this_space_shift.shape[1:],
            dtype=this_space_shift.dtype
        )
        for i, j in _iter_frame_blocks(num_frames, block_frame_length):
            # Accumulate over all blocks. The total is divided by the number
            # of all frames below; so, every block must contribute.
            this_space_shift_mean += this_space_shift[i:j].sum(axis=0)
        this_space_shift_mean = numpy.round(
            this_space_shift_mean.astype(float_type) / len(this_space_shift)
        ).astype(int)
        for i, j in _iter_frame_blocks(num_frames, block_frame_length):
            this_space_shift[i:j] = xnumpy.find_relative_offsets(
                this_space_shift[i:j],
                this_space_shift_mean
            )

        # Find the shortest roll possible (i.e. if it is going over halfway
        # switch direction so it will go less than half).
        # Note all indices by definition were positive semi-definite and upper
        # bounded by the shape. This change will make them bound by
        # the half shape, but with either sign.
        for i, j in _iter_frame_blocks(num_frames, block_frame_length):
            this_space_shift[i:j] = xnumpy.find_shortest_wraparound(
                this_space_shift[i:j],
                frames2reg_fft.shape[1:]
            )

        for i, j in _iter_frame_blocks(num_frames, block_frame_length):
            delta_space_shift_ij = this_space_shift[i:j] - space_shift[i:j]
            # Sum of the squared entries. Summing every entry of the matrix
            # product of the deltas with their transpose would also add the
            # cross-frame terms, which can cancel to zero while individual
            # frames are still changing and so end the iteration too early.
            squared_magnitude_delta_space_shift += (
                delta_space_shift_ij ** 2
            ).sum()

        for i, j in _iter_frame_blocks(num_frames, block_frame_length):
            space_shift[i:j] = this_space_shift[i:j]

        num_iters += 1
        logger.info(
            "Completed iteration, %i, " % num_iters +
            "where the L_2 norm squared of the relative shift was, %f." %
            squared_magnitude_delta_space_shift
        )

        if max_iters != -1:
            if num_iters >= max_iters:
                logger.info("Hit maximum number of iterations.")
                break

    reg_frames_shape = frames2reg.shape
    if to_truncate:
        # Largest positive and most negative shift per spatial dimension
        # (both start at zero; so, max >= 0 and min <= 0).
        space_shift_max = numpy.zeros(
            space_shift.shape[1:], dtype=space_shift.dtype
        )
        space_shift_min = numpy.zeros(
            space_shift.shape[1:], dtype=space_shift.dtype
        )
        for i, j in _iter_frame_blocks(num_frames, block_frame_length):
            numpy.maximum(
                space_shift_max,
                space_shift[i:j].max(axis=0),
                out=space_shift_max
            )
            numpy.minimum(
                space_shift_min,
                space_shift[i:j].min(axis=0),
                out=space_shift_min
            )

        # Shrink each spatial dimension by the extent of the shifts.
        reg_frames_shape = numpy.asarray(reg_frames_shape)
        reg_frames_shape[1:] -= space_shift_max
        reg_frames_shape[1:] += space_shift_min
        reg_frames_shape = tuple(reg_frames_shape)

        space_shift_max = tuple(space_shift_max)
        # A stop of 0 would select nothing; use None to run to the end.
        space_shift_min = space_shift_min.astype(object)
        space_shift_min[space_shift_min == 0] = None
        space_shift_min = tuple(space_shift_min)
        reg_frames_slice = tuple(
            slice(_1, _2) for _1, _2 in zip(space_shift_max, space_shift_min)
        )

    # Adjust the registered frames using the translations found.
    # Mask rolled values.
    reg_frames = None
    if tempdir_name:
        if to_truncate:
            reg_frames = temporaries_file.create_dataset(
                "reg_frames",
                shape=reg_frames_shape,
                dtype=frames2reg.dtype,
                chunks=True
            )
        else:
            reg_frames = temporaries_file.create_group("reg_frames")
            reg_frames = hdf5.serializers.HDF5MaskedDataset(
                reg_frames,
                shape=frames2reg.shape,
                dtype=frames2reg.dtype
            )
    else:
        if to_truncate:
            reg_frames = numpy.empty(reg_frames_shape, dtype=frames2reg.dtype)
        else:
            reg_frames = numpy.ma.empty_like(frames2reg)
            reg_frames.mask = numpy.ma.getmaskarray(reg_frames)
            reg_frames.set_fill_value(reg_frames.dtype.type(0))

    for i, j in _iter_frame_blocks(num_frames, block_frame_length):
        for k in xrange(i, j):
            if to_truncate:
                reg_frames[k] = xnumpy.roll(
                    frames2reg[k], space_shift[k]
                )[reg_frames_slice]
            else:
                reg_frames[k] = xnumpy.roll(
                    frames2reg[k], space_shift[k], to_mask=True
                )

    result = None
    results_filename = ""
    if tempdir_name:
        # Copy what the caller asked for into a results file, drop the
        # temporaries, and hand back the path of the results file.
        results_filename = os.path.join(tempdir_name, "results.h5")
        results_file = h5py.File(results_filename, "w")
        if to_truncate:
            temporaries_file.copy(reg_frames.name, results_file)
        else:
            temporaries_file.copy(reg_frames.group, results_file)
        if include_shift:
            temporaries_file.copy(space_shift, results_file)

        # Release all references into the temporaries file before closing
        # and removing it.
        frames2reg_fft = None
        reg_frames = None
        space_shift = None
        this_space_shift = None
        temporaries_file.close()
        os.remove(temporaries_filename)
        temporaries_filename = ""

        result = results_filename

        results_file.close()
        results_file = None
    else:
        result = reg_frames
        if include_shift:
            result = (reg_frames, space_shift)

    return result