def file_saving_child_process(
    data_buffers,
    buffer_shape,
    input_queue,
    output_queue,
    commands,
):
    """Worker loop that (optionally) saves shared data buffers as TIFs.

    Runs until a `None` "poison pill" arrives on `input_queue`. Every
    other item is a "permission slip" dict whose 'which_buffer' entry
    indexes `data_buffers`. If the slip also has a 'file_info' entry, the
    first prod(buffer_shape) uint16 values of that buffer are reshaped to
    `buffer_shape` and written with np_tif.array_to_tif(a, **file_info).
    Each slip, including the poison pill, is then put on `output_queue`.

    Messages on the `commands` pipe are handled before the work queue is
    checked. Only 'set_buffer_shape' is acted on: it replaces the shape
    and echoes the new shape back on the pipe. Any other command is
    logged and otherwise ignored.

    Raises ImportError (after logging where to get it) if np_tif.py
    can't be imported. Returns None.
    """
    try:
        import np_tif
    except ImportError:
        info("Failed to import np_tif.py; go get it from github:")
        info("https://github.com/AndrewGYork/tools/blob/master/np_tif.py")
        raise
    buffer_size = np.prod(buffer_shape)
    while True:
        # Control commands take priority over the work queue.
        if commands.poll():
            cmd, args = commands.recv()
            info("Command received:" + cmd)
            if cmd == 'set_buffer_shape':
                buffer_shape = args['shape']
                buffer_size = np.prod(buffer_shape)
                commands.send(buffer_shape)
            continue
        try:
            permission_slip = input_queue.get_nowait()
        except Q.Empty:
            # Nothing queued yet; nap briefly (the OS may round this up).
            sleep(0.001)
            continue
        if permission_slip is None:
            # Poison pill: forward it downstream, then stop.
            output_queue.put(permission_slip)
            break
        which_buffer = permission_slip['which_buffer']
        info("start buffer %i" % (which_buffer))
        started = clock()
        if 'file_info' in permission_slip:
            # Only slips carrying file info get written to disk. The value
            # is a dict of keyword arguments for np_tif.array_to_tif()
            # (e.g. the file name).
            info("saving buffer %i" % (which_buffer))
            file_info = permission_slip['file_info']
            shared = data_buffers[which_buffer]
            # Hold the buffer's lock for the whole read-and-save.
            with shared.get_lock():
                flat = np.frombuffer(shared.get_obj(), dtype=np.uint16)
                image = flat[:buffer_size].reshape(buffer_shape)
                np_tif.array_to_tif(image, **file_info)
        info("end buffer %i, elapsed time %06f" %
             (which_buffer, clock() - started))
        output_queue.put(permission_slip)
    return None
Example #2
0
 def _calibrate_laser(self):
     """Set `self.laser_calibration`, measuring it first if needed.

     The calibration is cached in 'shaker_calibration.tif' in the
     current directory. If that file is missing, it is measured as
     follows:
     - The mirror sweeps left-to-right (hopefully at constant speed)
       and the measurement window is the middle third of the sweep.
     - The modulator is held at a constant voltage for each
       illumination, and that voltage ramps linearly from 0 to 0.6
       over the series of illuminations.
     - Pixels whose median-filtered variance/mean ratio exceeds 40% of
       its maximum are summed for each illumination.
     - The sums are offset so the first illumination reads zero, then
       saved together with the modulator voltages.

     Afterwards `self.laser_calibration` is a dict with keys
     'modulator_voltage' and 'illumination_brightness'. The brightness
     is smoothed with a sigma=0.5 Gaussian filter. Returns None.
     """
     if not os.path.exists('shaker_calibration.tif'):
         # Mirror: one linear ramp, three windows long; we measure
         # during the middle window.
         window = 80000
         mirror_voltage = np.linspace(0, 2, 3*window)
         window_start = (mirror_voltage.shape[0] - window) // 2
         # Modulator: round the requested illumination count to a whole
         # number (at least one) of snaps of `per_snap` illuminations,
         # then ramp the voltage linearly across all of them.
         per_snap = self.idp.buffer_shape[0]
         desired_num_illuminations = 45
         num_snaps = max(1, int(np.round(desired_num_illuminations /
                                         per_snap)))
         num_illuminations = num_snaps * per_snap
         modulator_voltage = np.linspace(0, 0.6, num_illuminations)
         # Each illumination holds its voltage across the whole window.
         illuminations = np.outer(modulator_voltage, np.ones(window))
         illuminations = illuminations.reshape(num_snaps, per_snap, window)
         for snap, snap_illumination in enumerate(illuminations):
             self.snap_mirror_motion(
                 mirror_voltage,
                 measurement_start_pixel=window_start,
                 measurement_pixels=window,
                 filename='calibration%i.tif'%snap,
                 illumination=snap_illumination)
         # Read back, then delete, the temporary per-snap files.
         snaps = []
         for snap in range(num_snaps):
             snap_file = 'calibration%i.tif'%snap
             snaps.append(
                 np_tif.tif_to_array(snap_file).astype(np.float32))
             os.remove(snap_file)
         data = np.concatenate(snaps, axis=0)
         # Keep pixels whose signal varies strongly (relative to its mean)
         # across illuminations -- presumably the illuminated region.
         ratio = filters.median_filter(
             data.std(axis=0)**2 / data.mean(axis=0), size=3)
         mask = ratio > 0.4 * ratio.max()
         brightness = (data * mask[np.newaxis, ...]
                       ).sum(axis=-1).sum(axis=-1)
         brightness -= brightness[0]
         np_tif.array_to_tif(np.array([modulator_voltage, brightness]),
                             'shaker_calibration.tif')
     stored = np_tif.tif_to_array('shaker_calibration.tif'
                                  ).astype(np.float64)
     self.laser_calibration = {
         'modulator_voltage': stored[0, 0, :],
         'illumination_brightness': filters.gaussian_filter(
             stored[0, 1, :], sigma=0.5)}
     return None
Example #3
0
 def _save_snapped_image_as_tif(self, filename, verbose=True):
     """Write the most recently snapped image to `filename` as a TIF.

     Don't call this method directly; it should be a side effect of
     calling the 'snap' method.

     Args:
         filename: Destination path; must end in '.tif'.
         verbose: If True, print a progress message before the save and
             " done" after it.

     Raises:
         UserWarning: if np_tif.py could not be imported.
         AssertionError: if `filename` doesn't end in '.tif' (only
             checked when assertions are enabled).
     """
     if np_tif is None:
         raise UserWarning(
             "If you want to save as TIF, get np_tif.py here:\n  " +
             "https://github.com/AndrewGYork/tools/blob/master/np_tif.py\n" +
             "We failed to import np_tif.py.\n" +
             "This means we can't save camera images as TIFs.")
     assert filename.endswith('.tif')
     image = self._snapped_image_as_numpy_array()
     if verbose:
         # The message has no trailing newline, so flush explicitly.
         # Otherwise a buffered stdout holds it until " done" is printed,
         # and no progress is visible during the save.
         print(" Saving a ", image.shape, ' ', image.dtype, " image as '",
               filename, "'...", sep='', end='', flush=True)
     np_tif.array_to_tif(image, filename)
     if verbose:
         print(" done")
     return None
Example #4
0
    def record_iteration(self, save_tifs=True):
        """Snapshot the current estimate; optionally dump history TIFs.

        Appends `self.num_iterations` to `self.saved_iterations` and a copy
        of `self.estimate` to `self.estimate_history`. When `save_tifs` is
        true, also writes two files under `self.output_prefix`:
          - 'estimate_history.tif': all saved estimates, concatenated
            along axis 0 and squeezed;
          - 'estimate_FT_error_history.tif': log(1 + |FT|) of each saved
            estimate minus `self.true_object`, FFT'd and fftshift'ed over
            axes 1 and 2.
        Returns None.
        """
        self.saved_iterations.append(self.num_iterations)
        self.estimate_history.append(self.estimate.copy())
        if not save_tifs:
            return None

        def log_ft_magnitude(stack):
            # Treat a lone 2D image as a one-slice stack.
            if stack.ndim == 2:
                stack = stack[np.newaxis, :, :]
            spectrum = np.fft.fftn(stack, axes=(1, 2))
            return np.log(1 + np.abs(np.fft.fftshift(spectrum, axes=(1, 2))))

        history = np.squeeze(np.concatenate(self.estimate_history, axis=0))
        np_tif.array_to_tif(history,
                            self.output_prefix + 'estimate_history.tif')
        np_tif.array_to_tif(
            log_ft_magnitude(history - self.true_object),
            self.output_prefix + 'estimate_FT_error_history.tif')
        return None
Example #5
0
 def record_data(self):
     """Save whichever of the recorded arrays exist on `self` as TIFs.

     The attributes are checked in this order: 'psfs', 'true_object',
     'noiseless_measurement', 'noisy_measurement'. Each one that is
     present is written to `self.output_prefix` plus a matching file
     name. The sequence-valued ones are concatenated along axis 0 and
     squeezed first; 'true_object' is written as-is. Returns None.
     """
     # (attribute name, output file suffix, concatenate-and-squeeze?)
     outputs = (
         ('psfs', 'psfs.tif', True),
         ('true_object', 'object.tif', False),
         ('noiseless_measurement', 'noiseless_measurement.tif', True),
         ('noisy_measurement', 'noisy_measurement.tif', True),
     )
     for attr, suffix, stacked in outputs:
         if not hasattr(self, attr):
             continue
         value = getattr(self, attr)
         if stacked:
             value = np.squeeze(np.concatenate(value, axis=0))
         np_tif.array_to_tif(value, self.output_prefix + suffix)
     return None
def stack_registration(
    s,
    align_to_this_slice=0,
    refinement='spike_interpolation',
    register_in_place=True,
    fourier_cutoff_radius=None,
    background_subtraction=None,
    debug=False,
):
    """Calculate shifts which would register the slices of a
    three-dimensional stack `s`, and optionally register the stack in-place.

    Axis 0 is the "z-axis", axis 1 is the "up-down" (Y) axis, and axis 2
    is the "left-right" (X) axis. For each XY slice, we calculate the
    shift in the XY plane which would line that slice up with the slice
    specified by `align_to_this_slice`. If `align_to_this_slice` is a
    number, it indicates which slice of `s` to use as the reference
    slice. If `align_to_this_slice` is a numpy array, it is used as the
    reference slice, and must be the same shape as a 2D slice of `s`
    (after squeezing out length-1 axes).

    `refinement` is one of `integer`, `spike_interpolation`, or
    `phase_fitting`, in order of increasing precision/slowness. I don't
    yet have any evidence that my implementation of phase fitting gives
    any improvement over (faster, simpler) spike interpolation, so
    caveat emptor.

    `register_in_place`: If `True`, modify the input stack `s` by
    shifting its slices to line up with the reference slice.

    `fourier_cutoff_radius`: Ignore the Fourier phases of spatial
    frequencies higher than this cutoff (in cycles per pixel, so
    0 < cutoff <= 0.5), since they're probably lousy due to aliasing and
    noise anyway. If `None`, attempt to estimate a reasonable cutoff.

    `background_subtraction`: One of None, 'mean', 'min', or
    'edge_mean'. Image registration is sensitive to edge effects. To
    combat this, we multiply the image by a real-space mask which goes
    to zero at the edges. For dim images on large DC backgrounds, the
    registration can end up mistaking this mask for an important image
    feature, distorting the registration. Sometimes it's helpful to
    subtract a background from the image before registration, to reduce
    this effect. 'mean' and 'min' subtract the mean and minimum of the
    stack `s`, respectively, and 'edge_mean' subtracts the mean of the
    edge pixels. Use None (the default) for no background subtraction.

    `debug`: If `True`, print progress and save intermediate results as
    'DEBUG_*.tif' files in the current directory (requires np_tif).

    Returns a numpy array with one (up-down, left-right) shift per slice
    of `s`, i.e. of shape (s.shape[0], 2).
    """
    assert len(s.shape) == 3
    ## Resolve the reference slice. A scalar is an index into `s`;
    ## anything else is the reference image itself. Decide by
    ## dimensionality instead of catching the ValueError that
    ## `array in range(...)` raises inside an `assert`, because asserts
    ## vanish under `python -O`.
    if np.ndim(align_to_this_slice) == 0:
        assert align_to_this_slice in range(s.shape[0])
        align_to_this_slice = s[align_to_this_slice, :, :]
    else:
        align_to_this_slice = np.squeeze(align_to_this_slice)
    assert align_to_this_slice.shape == s.shape[-2:]
    assert refinement in ('integer', 'spike_interpolation', 'phase_fitting')
    if refinement == 'phase_fitting' and minimize is None:
        raise UserWarning("Failed to import scipy minimize; no phase fitting.")
    assert register_in_place in (True, False)
    ## What background should we subtract from each slice of the stack?
    ## Compare with `==`, not `is`: equal strings built at runtime need
    ## not be the same object, which would otherwise leave `bg` unbound.
    assert background_subtraction in (None, 'mean', 'min', 'edge_mean')
    if background_subtraction is None:
        bg = 0
    elif background_subtraction == 'min':
        bg = s.min()
    elif background_subtraction == 'mean':
        bg = s.mean()
    elif background_subtraction == 'edge_mean':
        bg = np.mean((s[:, 0, :].mean(), s[:, -1, :].mean(),
                      s[:, :, 0].mean(), s[:, :, -1].mean()))
    else:
        raise ValueError(
            "Unknown background_subtraction: %r" % (background_subtraction,))
    if fourier_cutoff_radius is None:
        fourier_cutoff_radius = estimate_fourier_cutoff_radius(s, bg, debug)
    assert (0 < fourier_cutoff_radius <= 0.5)
    assert debug in (True, False)
    if debug and np_tif is None:
        raise UserWarning("Failed to import np_tif; no debug mode.")
    ## Multiply each slice of the stack by an XY mask that goes to zero
    ## at the edges, to prevent periodic boundary artifacts when we
    ## Fourier transform.
    mask_ud = np.sin(np.linspace(0, np.pi, s.shape[1])).reshape(s.shape[1], 1)
    mask_lr = np.sin(np.linspace(0, np.pi, s.shape[2])).reshape(1, s.shape[2])
    masked_reference_slice = (align_to_this_slice - bg) * mask_ud * mask_lr
    ## We'll base our registration on the phase of the low spatial
    ## frequencies of the cross-power spectrum. We'll need the complex
    ## conjugate of the Fourier transform of the masked reference slice,
    ## and a mask in the Fourier domain to pick out the low spatial
    ## frequencies:
    ref_slice_ft_conj = np.conj(np.fft.rfftn(masked_reference_slice))
    k_ud = np.fft.fftfreq(s.shape[1]).reshape(ref_slice_ft_conj.shape[0], 1)
    k_lr = np.fft.rfftfreq(s.shape[2]).reshape(1, ref_slice_ft_conj.shape[1])
    fourier_mask = (k_ud**2 + k_lr**2) < (fourier_cutoff_radius)**2
    ## Now we'll loop over each slice of the stack, calculate our
    ## registration shifts, and optionally apply the shifts to the
    ## original stack.
    registration_shifts = []
    if debug:
        ## Save some intermediate data to help with debugging
        masked_stack = np.zeros_like(s)
        masked_stack_ft = np.zeros((s.shape[0], ) + ref_slice_ft_conj.shape,
                                   dtype=np.complex128)
        masked_stack_ft_vs_ref = np.zeros_like(masked_stack_ft)
        cross_power_spectra = np.zeros_like(masked_stack_ft)
        spikes = np.zeros(s.shape, dtype=np.float64)
    for which_slice in range(s.shape[0]):
        if debug:
            print("Calculating registration for slice", which_slice)
        ## Compute the cross-power spectrum of our slice, and mask out
        ## the high spatial frequencies.
        current_slice = (s[which_slice, :, :] - bg) * mask_ud * mask_lr
        current_slice_ft = np.fft.rfftn(current_slice)
        cross_power_spectrum = current_slice_ft * ref_slice_ft_conj
        cross_power_spectrum = (fourier_mask * cross_power_spectrum /
                                np.abs(cross_power_spectrum))
        ## Inverse transform to get a 'spike' in real space. The
        ## location of this spike gives the desired registration shift.
        ## Start by locating the spike to the nearest integer:
        spike = np.fft.irfftn(cross_power_spectrum, s=current_slice.shape)
        loc = np.array(np.unravel_index(np.argmax(spike), spike.shape))
        if refinement in ('spike_interpolation', 'phase_fitting'):
            ## Use (very simple) three-point polynomial interpolation to
            ## refine the location of the peak of the spike:
            neighbors = np.array([-1, 0, 1])
            ud_vals = spike[(loc[0] + neighbors) % spike.shape[0], loc[1]]
            lr_vals = spike[loc[0], (loc[1] + neighbors) % spike.shape[1]]
            lr_fit = np.poly1d(np.polyfit(neighbors, lr_vals, deg=2))
            ud_fit = np.poly1d(np.polyfit(neighbors, ud_vals, deg=2))
            lr_max_shift = -lr_fit[1] / (2 * lr_fit[2])
            ud_max_shift = -ud_fit[1] / (2 * ud_fit[2])
            loc = loc + (ud_max_shift, lr_max_shift)
        ## Convert our shift into a signed number near zero (the spike
        ## location wraps around, so e.g. N - 2 means a shift of -2):
        loc = ((np.array(spike.shape) // 2 + loc) % np.array(spike.shape) -
               np.array(spike.shape) // 2)
        if refinement == 'phase_fitting':
            if debug:
                print("Phase fitting slice", which_slice, "...")

            ## (Attempt to) further refine our registration shift by
            ## fitting Fourier phases. I'm not sure this does any good,
            ## perhaps my implementation is lousy?
            def minimize_me(loc, cross_power_spectrum):
                disagreement = np.abs(
                    expected_cross_power_spectrum(loc, k_ud, k_lr) -
                    cross_power_spectrum)[fourier_mask].sum()
                if debug:
                    print(" Shift:", loc, "Disagreement:", disagreement)
                return disagreement

            loc = minimize(minimize_me,
                           x0=loc,
                           args=(cross_power_spectrum, ),
                           method='Nelder-Mead').x
        registration_shifts.append(loc)
        if debug:
            ## Save some intermediate data to help with debugging
            masked_stack[which_slice, :, :] = current_slice
            masked_stack_ft[which_slice, :, :] = (np.fft.fftshift(
                current_slice_ft, axes=0))
            masked_stack_ft_vs_ref[which_slice, :, :] = (np.fft.fftshift(
                current_slice_ft * ref_slice_ft_conj, axes=0))
            cross_power_spectra[which_slice, :, :] = (np.fft.fftshift(
                cross_power_spectrum, axes=0))
            spikes[which_slice, :, :] = np.fft.fftshift(spike)
    if register_in_place:
        ## Modify the input stack in-place so it's registered.
        if refinement == 'integer':
            registration_type = 'nearest_integer'
        else:
            registration_type = 'fourier_interpolation'
        apply_registration_shifts(s,
                                  registration_shifts,
                                  registration_type=registration_type)
    if debug:
        np_tif.array_to_tif(masked_stack, 'DEBUG_masked_stack.tif')
        np_tif.array_to_tif(np.log(np.abs(masked_stack_ft)),
                            'DEBUG_masked_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(np.angle(masked_stack_ft),
                            'DEBUG_masked_stack_FT_phases.tif')
        np_tif.array_to_tif(np.angle(masked_stack_ft_vs_ref),
                            'DEBUG_masked_stack_FT_phase_vs_ref.tif')
        np_tif.array_to_tif(np.angle(cross_power_spectra),
                            'DEBUG_cross_power_spectral_phases.tif')
        np_tif.array_to_tif(spikes, 'DEBUG_spikes.tif')
        if register_in_place:
            np_tif.array_to_tif(s, 'DEBUG_registered_stack.tif')
    return np.array(registration_shifts)
                               np.fft.rfft(stack[0, :, :]).shape)
    k_ud, k_lr = np.fft.fftfreq(stack.shape[1]), np.fft.rfftfreq(
        stack.shape[2])
    k_ud, k_lr = k_ud.reshape(k_ud.size, 1), k_lr.reshape(1, k_lr.size)
    for s, (y, x) in enumerate(shifts):
        top = max(0, y)
        lef = max(0, x)
        bot = min(obj.shape[-2], obj.shape[-2] + y)
        rig = min(obj.shape[-1], obj.shape[-1] + x)
        shifted_obj.fill(0)
        shifted_obj[0, top:bot, lef:rig] = obj[0, top - y:bot - y,
                                               lef - x:rig - x]
        stack[s, :, :] = np.random.poisson(
            brightness * bucket(gaussian_filter(shifted_obj, blur),
                                bucket_size)[0, crop:-crop, crop:-crop])
        expected_phases[s, :, :] = np.angle(
            np.fft.fftshift(expected_cross_power_spectrum(
                (y / bucket_size[1], x / bucket_size[2]), k_ud, k_lr),
                            axes=0))
    np_tif.array_to_tif(expected_phases, 'DEBUG_expected_phase_vs_ref.tif')
    np_tif.array_to_tif(stack, 'DEBUG_stack.tif')
    print(" Done.")
    print("Registering test stack...")
    calculated_shifts = stack_registration(stack,
                                           refinement='spike_interpolation',
                                           debug=True)
    print(" Done.")
    for s, cs in zip(shifts, calculated_shifts):
        print('%0.2f (%i)' % (cs[0] * bucket_size[1], s[0]),
              '%0.2f (%i)' % (cs[1] * bucket_size[2], s[1]))
def stack_registration(s,
                       align_to_this_slice=0,
                       refinement='spike_interpolation',
                       register_in_place=True,
                       registered_stack_is_masked=False,
                       fourier_mask_magnitude=0.15,
                       debug=False):
    """Calculate shifts which would register the slices of a
    three-dimensional stack 's', and optionally register the stack in-place.

    Axis 0 is the "z-axis", axis 1 is the "up-down" (Y) axis, and axis 2
    is the "left-right" (X) axis. For each XY slice, we calculate the
    shift in the XY plane which would line that slice up with the slice
    specified by 'align_to_this_slice' (an integer index into axis 0).

    'refinement' is one of 'integer', 'spike_interpolation', or
    'phase_fitting', in order of increasing precision/slowness. I don't
    yet have any evidence that my implementation of phase fitting gives
    any improvement over (faster, simpler) spike interpolation, so
    caveat emptor.

    'register_in_place': If 'True', modify the input stack 's' by
    shifting its slices to line up with the reference slice.

    'registered_stack_is_masked': We mask each slice of the stack so
    that it goes to zero at the edges, which reduces Fourier artifacts
    and improves registration accuracy. If we're also modifying the
    input stack, we can save one Fourier transform per iteration if
    we're willing to substitute the 'masked' version of each slice for
    its original value.

    'fourier_mask_magnitude': Ignore the Fourier phases of spatial
    frequencies above this cutoff (in cycles per pixel; must satisfy
    0 < cutoff < 0.5), since they're probably lousy due to aliasing and
    noise anyway.

    'debug': If True, print progress and save intermediate results as
    'DEBUG_*.tif' files in the current directory (requires np_tif).

    Returns a list with one (up-down, left-right) shift array per slice.
    The reference slice gets an exact (0, 0), unless 'debug' is True, in
    which case it is measured like every other slice.
    """
    assert len(s.shape) == 3
    assert align_to_this_slice in range(s.shape[0])
    assert refinement in ('integer', 'spike_interpolation', 'phase_fitting')
    if refinement == 'phase_fitting' and minimize is None:
        raise UserWarning("Failed to import scipy minimize; no phase fitting.")
    assert register_in_place in (True, False)
    assert registered_stack_is_masked in (True, False)
    assert 0 < fourier_mask_magnitude < 0.5
    assert debug in (True, False)
    if debug and np_tif is None:
        raise UserWarning("Failed to import np_tif; no debug mode.")
    ## Multiply each slice of the stack by an XY mask that goes to zero
    ## at the edges, to prevent periodic boundary artifacts when we
    ## Fourier transform.
    mask_ud = np.sin(np.linspace(0, np.pi, s.shape[1])).reshape(s.shape[1], 1)
    mask_lr = np.sin(np.linspace(0, np.pi, s.shape[2])).reshape(1, s.shape[2])
    masked_reference_slice = s[align_to_this_slice, :, :] * mask_ud * mask_lr
    ## We'll base our registration on the phase of the low spatial
    ## frequencies of the cross-power spectrum. We'll need the complex
    ## conjugate of the Fourier transform of the masked reference slice,
    ## and a mask in the Fourier domain to pick out the low spatial
    ## frequencies:
    ref_slice_ft_conj = np.conj(np.fft.rfftn(masked_reference_slice))
    k_ud = np.fft.fftfreq(s.shape[1]).reshape(ref_slice_ft_conj.shape[0], 1)
    k_lr = np.fft.rfftfreq(s.shape[2]).reshape(1, ref_slice_ft_conj.shape[1])
    fourier_mask = (k_ud**2 + k_lr**2) < (fourier_mask_magnitude)**2

    ## We can also use these Fourier frequencies to define a convenience
    ## function that gives the expected spectral phase associated with
    ## an arbitrary subpixel shift:
    def expected_cross_power_spectrum(shift):
        shift_ud, shift_lr = shift
        return np.exp(-2j * np.pi * (k_ud * shift_ud + k_lr * shift_lr))

    ## Now we'll loop over each slice of the stack, calculate our
    ## registration shifts, and optionally apply the shifts to the
    ## original stack.
    registration_shifts = []
    if debug:
        ## Save some intermediate data to help with debugging
        masked_stack = np.zeros_like(s)
        masked_stack_ft = np.zeros((s.shape[0], ) + ref_slice_ft_conj.shape,
                                   dtype=np.complex128)
        cross_power_spectra = np.zeros(
            (s.shape[0], ) + ref_slice_ft_conj.shape, dtype=np.complex128)
        ## The spike comes from a magnitude-normalized spectrum, so its
        ## values are at most 1 in magnitude. Store it as float64
        ## regardless of the stack's dtype; an integer stack would
        ## otherwise truncate it to zeros.
        spikes = np.zeros(s.shape, dtype=np.float64)
    for which_slice in range(s.shape[0]):
        if debug:
            print("Slice", which_slice)
        if which_slice == align_to_this_slice and not debug:
            ## The reference slice trivially needs no shift.
            registration_shifts.append(np.array((0, 0)))
            if register_in_place and registered_stack_is_masked:
                s[which_slice, :, :] = masked_reference_slice
            continue
        ## Compute the cross-power spectrum of our slice, and mask out
        ## the high spatial frequencies.
        current_slice = s[which_slice, :, :] * mask_ud * mask_lr
        current_slice_ft = np.fft.rfftn(current_slice)
        cross_power_spectrum = current_slice_ft * ref_slice_ft_conj
        cross_power_spectrum = (fourier_mask * cross_power_spectrum /
                                np.abs(cross_power_spectrum))
        ## Inverse transform to get a 'spike' in real space. The
        ## location of this spike gives the desired registration shift.
        ## Start by locating the spike to the nearest integer:
        spike = np.fft.irfftn(cross_power_spectrum, s=current_slice.shape)
        loc = np.array(np.unravel_index(np.argmax(spike), spike.shape))
        if refinement in ('spike_interpolation', 'phase_fitting'):
            ## Use (very simple) three-point polynomial interpolation to
            ## refine the location of the peak of the spike:
            neighbors = np.array([-1, 0, 1])
            ud_vals = spike[(loc[0] + neighbors) % spike.shape[0], loc[1]]
            lr_vals = spike[loc[0], (loc[1] + neighbors) % spike.shape[1]]
            lr_fit = np.poly1d(np.polyfit(neighbors, lr_vals, deg=2))
            ud_fit = np.poly1d(np.polyfit(neighbors, ud_vals, deg=2))
            lr_max_shift = -lr_fit[1] / (2 * lr_fit[2])
            ud_max_shift = -ud_fit[1] / (2 * ud_fit[2])
            loc = loc + (ud_max_shift, lr_max_shift)
        ## Convert our shift into a signed number near zero (the spike
        ## location wraps around, so e.g. N - 2 means a shift of -2):
        loc = ((np.array(spike.shape) // 2 + loc) % np.array(spike.shape) -
               np.array(spike.shape) // 2)
        if refinement == 'phase_fitting':
            if debug:
                print("Phase fitting slice", which_slice, "...")

            ## (Attempt to) further refine our registration shift by
            ## fitting Fourier phases. I'm not sure this does any good,
            ## perhaps my implementation is lousy?
            def minimize_me(loc, cross_power_spectrum):
                disagreement = np.abs(
                    expected_cross_power_spectrum(loc) -
                    cross_power_spectrum)[fourier_mask].sum()
                if debug:
                    print(" Shift:", loc, "Disagreement:", disagreement)
                return disagreement

            loc = minimize(minimize_me,
                           x0=loc,
                           args=(cross_power_spectrum, ),
                           method='Nelder-Mead').x
        registration_shifts.append(loc)
        if register_in_place:
            ## Modify the input stack in-place so it's registered, by
            ## dividing out the phase ramp the measured shift predicts.
            phase_correction = expected_cross_power_spectrum(loc)
            if registered_stack_is_masked:
                ## If we're willing to tolerate a "masked" result, we
                ## can save one Fourier transform:
                s[which_slice, :, :] = np.fft.irfftn(
                    current_slice_ft / phase_correction,
                    s=current_slice.shape).real
            else:
                ## Slower, but probably the right way to do it:
                shift_me = s[which_slice, :, :]
                if not shift_me.dtype == np.float64:
                    shift_me = shift_me.astype(np.float64)
                s[which_slice, :, :] = np.fft.irfftn(
                    np.fft.rfftn(shift_me) / phase_correction,
                    s=current_slice.shape).real
        if debug:
            ## Save some intermediate data to help with debugging
            masked_stack[which_slice, :, :] = current_slice
            masked_stack_ft[which_slice, :, :] = (np.fft.fftshift(
                current_slice_ft, axes=0))
            cross_power_spectra[which_slice, :, :] = (np.fft.fftshift(
                cross_power_spectrum * fourier_mask, axes=0))
            spikes[which_slice, :, :] = np.fft.fftshift(spike)
    if debug:
        np_tif.array_to_tif(masked_stack, 'DEBUG_masked_stack.tif')
        np_tif.array_to_tif(np.log(np.abs(masked_stack_ft)),
                            'DEBUG_masked_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(np.angle(masked_stack_ft),
                            'DEBUG_masked_stack_FT_phases.tif')
        np_tif.array_to_tif(np.angle(cross_power_spectra),
                            'DEBUG_cross_power_spectral_phases.tif')
        np_tif.array_to_tif(spikes, 'DEBUG_spikes.tif')
        if register_in_place:
            np_tif.array_to_tif(s, 'DEBUG_registered_stack.tif')
    return registration_shifts
Example #9
0
            for z_v in z_voltages:
                filename = (
                    'fluorescence' +
##                    'phase_angle_' + ang +
                    '_green' + gr_pow +
                    '_red' + rd_pow +
                    z_v +
                    '_up.tif')
                assert os.path.exists(filename)
                print("Loading", filename)
                data = np_tif.tif_to_array(filename).astype(np.float64)
                assert data.shape == (3*1, 128, 380)
                data = data.reshape(1, 3, 128, 380) # Stack to hyperstack
                data = data.mean(axis=0, keepdims=True) # Sum over reps
                data_list.append(data)
print("Done loading.")
data = np.concatenate(data_list)
# Merge the two leading axes of the 4D hyperstack so it can be written
# out as an ordinary 3D TIF stack.
tif_shape = (
    data.shape[0] * data.shape[1],
    data.shape[2],
    data.shape[3],
)

print("Saving...")
np_tif.array_to_tif(data.reshape(tif_shape),
                    'dataset_green_all_powers_up.tif')
print("Done saving.")
Example #10
0
def main():
    """Average repeated STE darkfield acquisitions, one output per green power.

    For every green power and every angle this loads
    'STE_darkfield_<angle>_green<green>_red<red>_many_delays.tif',
    reshapes it to (repetition, delay, y, x), trims over-exposed rows
    from the top and bottom, discards blacklisted repetitions, rescales
    each (repetition, delay) frame so its background level (as reported
    by `get_bg_level`) matches that of the repetition/delay average,
    and averages over repetitions. The per-angle results for one green
    power are saved together as 'dataset_green<green>.tif'.

    Returns None.
    """
    n_angles = 32
    n_reps = 1000
    n_delays = 5
    height, width = 128, 380
    crop = 3  # rows trimmed from both the top and the bottom
    cropped_height = height - 2 * crop

    # Inclusive (first, last) repetition ranges to throw away, keyed by
    # angle: dust particles crossing the field of view, or extreme red
    # power fluctuations. Ranges are listed in ascending order.
    green_on_blacklist = {
        1: [(432, 435), (694, 719), (860, 865)],
        5: [(54, 57), (505, 508), (901, 906)],
        10: [(214, 218), (421, 428)],
        14: [(356, 360), (391, 394), (711, 713), (774, 802)],
        18: [(208, 210), (661, 667), (989, 992)],
        21: [(181, 187)],
        22: [(63, 70), (328, 333), (440, 451), (544, 557), (897, 902),
             (935, 964)],
        24: [(287, 306), (922, 924)],
        25: [(69, 73), (639, 675), (880, 898)],
        26: [(667, 677)],
        27: [(9, 16), (557, 560), (664, 669)],
        29: [(219, 221), (366, 369), (452, 458), (871, 875)],
    }
    green_off_blacklist = {
        1: [(706, 708)],
        5: [(219, 222), (505, 510), (553, 557)],
        7: [(158, 165), (213, 220), (310, 316), (493, 497), (950, 961)],
        12: [(173, 176), (432, 434), (914, 922)],
        13: [(494, 527)],
        14: [(983, 987)],
        15: [(451, 458), (698, 715), (873, 883)],
        16: [(171, 178)],
        17: [(100, 104), (323, 327)],
        21: [(51, 56), (293, 295), (385, 390), (858, 864)],
        22: [(106, 109), (279, 285), (565, 580), (829, 834), (904, 924)],
    }

    # Each green power, paired with the blacklist that applies to it.
    green_settings = (
        ('_0mW', green_off_blacklist),
        ('_1500mW', green_on_blacklist),
    )
    red_power = '_300mW'

    averaged = np.zeros(
        (len(green_settings), n_angles, n_delays, cropped_height, width))

    for power_index, (green_power, blacklist) in enumerate(green_settings):
        for angle in range(n_angles):
            path = 'STE_darkfield_{}_green{}_red{}_many_delays.tif'.format(
                angle, green_power, red_power)
            assert os.path.exists(path)
            print("Loading", path)
            stack = np_tif.tif_to_array(path).astype(np.float64)
            assert stack.shape == (n_delays * n_reps, height, width)
            # Flat stack -> (repetition, delay, y, x) hyperstack, without
            # the over-exposed border rows.
            stack = stack.reshape(
                n_reps, n_delays, height, width)[:, :, crop:height - crop, :]

            # Remove blacklisted repetitions, last range first, so the
            # indices of the ranges still to be removed stay valid.
            for first, last in reversed(blacklist.get(angle, [])):
                stack = np.delete(stack, np.arange(first, last + 1), axis=0)
            print(angle, stack.shape)

            # Correct for laser intensity fluctuations: rescale every
            # frame so its background level matches the background level
            # of the repetition/delay-averaged image.
            reference_level = get_bg_level(stack.mean(axis=(0, 1)))
            frame_levels = get_bg_level(stack)
            scale = (reference_level / frame_levels).reshape(
                stack.shape[0], stack.shape[1], 1, 1)
            averaged[power_index, angle, ...] = (stack * scale).mean(axis=0)

        # One TIF per green power: (angle, delay) flattened into frames.
        print("Saving...")
        np_tif.array_to_tif(
            averaged[power_index].reshape(
                n_angles * n_delays, cropped_height, width),
            'dataset_green' + green_power + '.tif')
        print("Done saving.")

    return None
Exemple #11
0
densities from simulated 3D SIM data. To save computation time, we'll
ignore the y-dimension, and simulate x-z data.

This script outputs a bunch of TIF files on disk. Use ImageJ to view
them.
"""

# Define a 2D x-z test object
print("Constructing test object")
n_z, n_x = 60, 60
true_density = np.zeros((n_z, n_x))
# Two isolated point emitters placed symmetrically about the center...
true_density[n_z // 2 + 5, n_x // 2 + 1] = 1
true_density[n_z // 2 - 5, n_x // 2 - 1] = 1
# ...plus two rows of regularly spaced points (every 4th and every 5th
# x pixel, respectively).
true_density[n_z // 2, ::4] = 1
true_density[n_z // 2 - 10, ::5] = 1
np_tif.array_to_tif(true_density, '1_true_density.tif')

# Define an x-z emission PSF
print("Constructing emission PSF")
na_limit = 0.25 * np.pi  # Max angle (radians) between k and the +k_z axis
k_magnitude = 0.15  # Shell radius, in np.fft.fftfreq units (cycles/pixel)
# Spatial-frequency grids in unshifted FFT order; the reshapes make them
# broadcast against each other to shape (n_z, n_x).
k_z = np.fft.fftfreq(n_z).reshape(n_z, 1)
k_x = np.fft.fftfreq(n_x).reshape(1, n_x)
k_abs = np.sqrt(k_x**2 + k_z**2)
# Angle of each k vector from the +k_z axis. At k = 0 this is 0/0 = nan,
# which nan_to_num maps to 0; errstate silences the resulting warnings.
with np.errstate(divide='ignore',
                 invalid='ignore'):  # Ugly divide-by-zero code
    k_theta = np.nan_to_num(np.arccos(k_z / k_abs))
# Fourier-space field: 1 on the thin shell | |k| - k_magnitude | < 0.01,
# restricted to angles below na_limit; 0 everywhere else.
psf_field_ft = np.zeros((n_z, n_x), dtype=np.complex128)
psf_field_ft[(k_theta < na_limit) &  # Limited NA
             (np.abs(k_abs - k_magnitude) < 0.01)  # Monochromatic
             ] = 1
## If no hyperstack file then process original data
else:
    data_list = num_f*[None]
    for fn in range(num_f):
        input_filename_list[fn] = (input_filename
                                   %((fn-int(0.5*(num_f-1)))*mz_step))
        print('Found input files, loading...', end='')
        data_list[fn] = np_tif.tif_to_array(input_filename_list[fn])
        print('done')
    print('Creating np data array...', end='')
    data = np.asarray(data_list)
    data = data.reshape((num_f, num_slices) + data.shape[-2:])        
    print('done')
    print("Saving result...", end='')
    np_tif.array_to_tif(
        data.reshape(num_f*num_slices, data.shape[-2], data.shape[-1]),
        hyperstack_filename, slices=num_slices, channels=1, frames=num_f)
    print('done')

print('tif shape (Microsope z, RR z, y, x) =', data.shape)

## Add white scale bar to all images
# For every (t, z) image, set rows 50-59 and columns 1800-1984 to 5000.
# Assumes `data` is a numpy array (it is built with np.asarray in the
# branch above), so `image` is a view and the write lands in `data`.
for t in range(data.shape[0]):
    for z in range(data.shape[1]):
        image = data[t, z, :, :]
        image[50:60, 1800:1985] = 5000

## Choose parameters for video
# NOTE(review): meanings inferred from names only; confirm against the
# plotting code that consumes them.
current_frame = -1
xmargin = 0.15
ymargin = 0.15
Exemple #13
0
def stack_rotational_registration(
    s,
    align_to_this_slice=0,
    refinement='spike_interpolation',
    register_in_place=True,
    fourier_cutoff_radius=None,
    fail_180_test='fix_but_print_warning',
    debug=False,
):
    """Calculate rotations which would rotationally register the slices
    of a three-dimensional stack `s`, and optionally rotate the stack
    in-place.

    Axis 0 is the "z-axis", axis 1 is the "up-down" (Y) axis, and axis 2
    is the "left-right" (X) axis. For each XY slice, we calculate the
    rotation in the XY plane which would line that slice up with the
    slice specified by `align_to_this_slice`. If `align_to_this_slice`
    is a number, it indicates which slice of `s` to use as the reference
    slice. If `align_to_this_slice` is a numpy array, it is used as the
    reference slice, and must be the same shape as a 2D slice of `s`.

    `refinement` is pretty much just `spike_interpolation`, until I get
    around to adding other refinement techniques.

    `register_in_place`: If `True`, modify the input stack `s` by
    rotating its slices to line up with the reference slice.

    `fourier_cutoff_radius`: Ignore the Fourier amplitudes of spatial
    frequencies higher than this cutoff, since they're probably lousy
    due to aliasing and noise anyway. If `None`, attempt to estimate a
    reasonable cutoff.

    'fail_180_test': One of 'fix_but_print_warning', (the default),
    'fix_silently', 'ignore_silently', or 'raise_exception'. The
    algorithm employed here has a fundamental ambiguity: due to
    symmetries in the Fourier domain under 180 degree rotations, the
    rotational registration won't return rotations bigger than +/- 90
    degrees. We currently have a lousy correlation-based check to detect
    if slice(s) in your stack seem to need bigger rotations to align
    with the reference slice. What would you like to happen when this
    warning triggers?

    `debug`: If `True`, print progress messages and write several
    DEBUG_*.tif files of intermediate results (requires np_tif).

    Returns a 1D numpy array holding one rotation angle, in degrees,
    per slice of `s`. These are the angles passed to `rotate` when
    `register_in_place` is `True`.
    """
    # TODO: take advantage of periodic/smooth decomposition
    assert len(s.shape) == 3
    ## Is `align_to_this_slice` a slice index or a reference image? This
    ## dispatch must not rely on `assert`: asserts are stripped under
    ## `python -O`, which would send a reference image into `s[...]` as
    ## a fancy index. Testing a numpy array for membership in a range
    ## raises ValueError (ambiguous truth value), which is how we
    ## recognize the reference-image case.
    try:
        is_slice_index = align_to_this_slice in range(s.shape[0])
    except ValueError:  # Maybe align_to_this_slice is a numpy array?
        align_to_this_slice = np.squeeze(align_to_this_slice)
    else:
        if not is_slice_index:
            raise AssertionError(
                "align_to_this_slice must be a slice index of `s`" +
                " or a 2D reference slice")
        align_to_this_slice = s[align_to_this_slice, :, :]
    assert align_to_this_slice.shape == s.shape[-2:]
    ## Create a square-cropped view 'c' of the input stack 's'. We're
    ## going to use a circular mask anyway, and this saves us from
    ## having to worry about nonuniform k_x, k_y sampling in Fourier
    ## space:
    delta = s.shape[1] - s.shape[2]
    y_slice, x_slice = slice(s.shape[1]), slice(s.shape[2])  # Default: no crop
    if delta > 0:  # Crop Y
        y_slice = slice(delta // 2, delta // 2 - delta)
    elif delta < 0:  # Crop X
        x_slice = slice(-delta // 2, delta - delta // 2)
    c = s[:, y_slice, x_slice]
    align_to_this_slice = align_to_this_slice[y_slice, x_slice]
    assert c.shape[1] == c.shape[2]  # Take this line out in a few months!
    assert refinement in ('spike_interpolation', )
    assert register_in_place in (True, False)
    if fourier_cutoff_radius is None:
        fourier_cutoff_radius = estimate_fourier_cutoff_radius(c, debug=debug)
    assert (0 < fourier_cutoff_radius <= 0.5)
    assert fail_180_test in ('fix_but_print_warning', 'fix_silently',
                             'ignore_silently', 'raise_exception')
    assert debug in (True, False)
    if debug and np_tif is None:
        raise UserWarning("Failed to import np_tif; no debug mode.")
    if map_coordinates is None:
        raise UserWarning("Failed to import scipy map_coordinates;" +
                          " no stack_rotational_registration.")
    ## We'll multiply each slice of the stack by a circular mask to
    ## prevent edge-effect artifacts when we Fourier transform:
    mask_ud = np.arange(-c.shape[1] / 2, c.shape[1] / 2).reshape(c.shape[1], 1)
    mask_lr = np.arange(-c.shape[2] / 2, c.shape[2] / 2).reshape(1, c.shape[2])
    mask_r_sq = mask_ud**2 + mask_lr**2
    max_r_sq = (c.shape[1] / 2)**2
    mask = (mask_r_sq <= max_r_sq) * np.cos(
        (np.pi / 2) * (mask_r_sq / max_r_sq))
    del mask_ud, mask_lr, mask_r_sq
    masked_reference_slice = align_to_this_slice * mask
    small_ref = bucket(masked_reference_slice, (4, 4))
    ## We'll base our rotational registration on the logarithms of the
    ## amplitudes of the low spatial frequencies of the reference and
    ## target slices. We'll need the amplitudes of the Fourier transform
    ## of the masked reference slice:
    ref_slice_ft_log_amp = np.log(
        np.abs(np.fft.fftshift(np.fft.rfftn(masked_reference_slice), axes=0)))
    ## Transform the reference slice log amplitudes to polar
    ## coordinates. Note that we avoid some potential subtleties here
    ## because we've cropped to a square field of view (which was the
    ## right thing to do anyway):
    n_y, n_x = ref_slice_ft_log_amp.shape
    k_r = np.arange(1, 2 * fourier_cutoff_radius * n_x)
    ## np.linspace needs an integer sample count; np.ceil returns a
    ## float, which modern numpy rejects with a TypeError.
    k_theta = np.linspace(-np.pi / 2, np.pi / 2,
                          int(np.ceil(2 * fourier_cutoff_radius * n_y)))
    k_theta_delta_degrees = (k_theta[1] - k_theta[0]) * 180 / np.pi
    k_r, k_theta = k_r.reshape(len(k_r), 1), k_theta.reshape(1, len(k_theta))
    k_y = k_r * np.sin(k_theta) + n_y // 2
    k_x = k_r * np.cos(k_theta)
    del k_r, k_theta, n_y, n_x
    polar_ref = map_coordinates(ref_slice_ft_log_amp, (k_y, k_x))
    polar_ref_ft_conj = np.conj(np.fft.rfft(polar_ref, axis=1))
    n_y, n_x = polar_ref_ft_conj.shape
    y, x = np.arange(n_y).reshape(n_y, 1), np.arange(n_x).reshape(1, n_x)
    polar_fourier_mask = (y > x)  # Triangular half-space of good FT phases
    del n_y, n_x, y, x
    ## Now we'll loop over each slice of the stack, calculate our
    ## registration rotations, and optionally apply the rotations to the
    ## original stack.
    registration_rotations_degrees = []
    if debug:
        ## Save some intermediate data to help with debugging. The
        ## masked slices are float, so don't inherit a (possibly
        ## integer) dtype from the input stack:
        n_z = c.shape[0]
        masked_stack = np.zeros(c.shape, dtype=np.float64)
        masked_stack_ft_log_amp = np.zeros((n_z, ) +
                                           ref_slice_ft_log_amp.shape)
        polar_stack = np.zeros((n_z, ) + polar_ref.shape)
        polar_stack_ft = np.zeros((n_z, ) + polar_ref_ft_conj.shape,
                                  dtype=np.complex128)
        cross_power_spectra = np.zeros_like(polar_stack_ft)
        spikes = np.zeros_like(polar_stack)
    for which_slice in range(c.shape[0]):
        if debug:
            print("Calculating rotational registration for slice", which_slice)
        ## Compute the polar transform of the log of the amplitude
        ## spectrum of our masked slice:
        current_slice = c[which_slice, :, :] * mask
        current_slice_ft_log_amp = np.log(
            np.abs(np.fft.fftshift(np.fft.rfftn(current_slice), axes=0)))
        polar_slice = map_coordinates(current_slice_ft_log_amp, (k_y, k_x))
        ## Register the polar transform of the current slice against the
        ## polar transform of the reference slice, using a similar
        ## algorithm to the one in mr_stacky:
        polar_slice_ft = np.fft.rfft(polar_slice, axis=1)
        cross_power_spectrum = polar_slice_ft * polar_ref_ft_conj
        ## Normalize to unit magnitude (keep only the phase). Guard
        ## against division by zero, as stack_registration does, so a
        ## zero coefficient can't poison the spike with NaNs:
        norm = np.abs(cross_power_spectrum)
        norm[norm == 0] = 1
        cross_power_spectrum = (polar_fourier_mask * cross_power_spectrum /
                                norm)
        ## Inverse transform to get a 'spike' in each row of fixed k_r;
        ## the location of this spike gives the desired rotation in
        ## theta pixels, which we can convert back to an angle. Start by
        ## locating the spike to the nearest integer:
        spike = np.fft.irfft(cross_power_spectrum,
                             axis=1,
                             n=polar_slice.shape[1])
        spike_1d = spike.sum(axis=0)
        loc = np.argmax(spike_1d)
        if refinement == 'spike_interpolation':
            ## Use (very simple) three-point polynomial interpolation to
            ## refine the location of the peak of the spike:
            neighbors = np.array([-1, 0, 1])
            lr_vals = spike_1d[(loc + neighbors) % spike.shape[1]]
            lr_fit = np.poly1d(np.polyfit(neighbors, lr_vals, deg=2))
            lr_max_shift = -lr_fit[1] / (2 * lr_fit[2])
            loc += lr_max_shift
        ## Convert our shift into a signed number near zero:
        loc = (loc +
               spike.shape[1] // 2) % spike.shape[1] - spike.shape[1] // 2
        ## Convert this shift in "theta pixels" back to a rotation in degrees:
        angle = loc * k_theta_delta_degrees
        ## There's a fundamental 180-degree ambiguity in the rotation
        ## determined by this algorthim. We need to check if adding 180
        ## degrees explains our data better. This is a half-assed fast
        ## check that just downsamples the bejesus out of the relevant
        ## slices and looks at cross-correlation.
        ## TODO: FIX THIS FOR REAL! ...or does it matter? Testing, for now.
        small_current_slice = bucket(current_slice, (4, 4))
        small_cur_000 = rotate(small_current_slice, angle=angle, reshape=False)
        small_cur_180 = rotate(small_current_slice,
                               angle=angle + 180,
                               reshape=False)
        if (small_cur_180 * small_ref).sum() > (small_cur_000 *
                                                small_ref).sum():
            ## Compare strings with '==', not 'is': identity of equal
            ## strings is an interning accident, not a guarantee.
            if fail_180_test == 'fix_but_print_warning':
                angle += 180
                print(" **Warning: potentially ambiguous rotation detected**")
                print("   Inspect the registration for rotations off by 180" +
                      " degrees, at slice %i" % (which_slice))
            elif fail_180_test == 'fix_silently':
                angle += 180
            elif fail_180_test == 'ignore_silently':
                pass
            elif fail_180_test == 'raise_exception':
                raise UserWarning(
                    "Potentially ambiguous rotation detected.\n" +
                    "One of your slices needed more than 90 degrees rotation")
        ## Pencils down.
        registration_rotations_degrees.append(angle)
        if debug:
            ## Save some intermediate data to help with debugging
            masked_stack[which_slice, :, :] = current_slice
            masked_stack_ft_log_amp[which_slice, :, :] = (
                current_slice_ft_log_amp)
            polar_stack[which_slice, :, :] = polar_slice
            polar_stack_ft[which_slice, :, :] = polar_slice_ft
            cross_power_spectra[which_slice, :, :] = cross_power_spectrum
            spikes[which_slice, :, :] = np.fft.fftshift(spike, axes=1)
    if register_in_place:
        ## Rotate the slices of the input stack in-place so it's
        ## rotationally registered:
        for which_slice in range(s.shape[0]):
            s[which_slice, :, :] = rotate(
                s[which_slice, :, :],
                angle=registration_rotations_degrees[which_slice],
                reshape=False)
    if debug:
        np_tif.array_to_tif(masked_stack, 'DEBUG_masked_stack.tif')
        np_tif.array_to_tif(masked_stack_ft_log_amp,
                            'DEBUG_masked_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(polar_stack, 'DEBUG_polar_stack.tif')
        np_tif.array_to_tif(np.angle(polar_stack_ft),
                            'DEBUG_polar_stack_FT_phases.tif')
        np_tif.array_to_tif(np.log(np.abs(polar_stack_ft)),
                            'DEBUG_polar_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(np.angle(cross_power_spectra),
                            'DEBUG_cross_power_spectral_phases.tif')
        np_tif.array_to_tif(spikes, 'DEBUG_spikes.tif')
        if register_in_place:
            np_tif.array_to_tif(s, 'DEBUG_registered_stack.tif')
    return np.array(registration_rotations_degrees)
Exemple #14
0
def stack_registration(
    s,
    align_to_this_slice=0,
    refinement='spike_interpolation',
    register_in_place=True,
    fourier_cutoff_radius=None,
    debug=False,
):
    """Calculate shifts which would register the slices of a
    three-dimensional stack `s`, and optionally register the stack in-place.

    Axis 0 is the "z-axis", axis 1 is the "up-down" (Y) axis, and axis 2
    is the "left-right" (X) axis. For each XY slice, we calculate the
    shift in the XY plane which would line that slice up with the slice
    specified by `align_to_this_slice`. If `align_to_this_slice` is an
    integer, it indicates which slice of `s` to use as the reference
    slice. If `align_to_this_slice` is a numpy array, it is used as the
    reference slice, and must be the same shape as a 2D slice of `s`.

    `refinement` is one of `integer`, (fast registration to the nearest
    pixel) or `spike_interpolation` (slower but hopefully more accurate
    sub-pixel registration).

    `register_in_place`: If `True`, modify the input stack `s` by
    shifting its slices to line up with the reference slice.

    `fourier_cutoff_radius`: Ignore the Fourier phases of spatial
    frequencies higher than this cutoff, since they're probably lousy
    due to aliasing and noise anyway. If `None`, attempt to estimate a
    reasonable cutoff.

    `debug`: (Attempt to) output several TIF files that are often useful
    when stack_registration isn't working. This needs to import
    np_tif.py; you can get a copy of np_tif.py from
    https://github.com/AndrewGYork/tools/blob/master/np_tif.py
    Probably the best way to understand what these files are, and how to
    interpret them, is to read the code of the rest of this function.

    Returns an array of shape (s.shape[0], 2): one (up-down, left-right)
    shift per slice, in pixels, wrapped into a signed range around
    zero. When the reference is given as a slice index (and `debug` is
    False), that slice's shift is exactly zero.
    """
    assert len(s.shape) == 3
    ## Is `align_to_this_slice` a slice index or a reference image? This
    ## dispatch must not rely on `assert`: asserts are stripped under
    ## `python -O`, which would send a reference image into `s[...]` as
    ## a fancy index. Testing a numpy array for membership in a range
    ## raises ValueError (ambiguous truth value), which is how we
    ## recognize the reference-image case.
    try:
        is_slice_index = align_to_this_slice in range(s.shape[0])
    except ValueError:  # Maybe align_to_this_slice is a numpy array?
        skip_aligning_this_slice = None
        align_to_this_slice = np.squeeze(align_to_this_slice)
    else:
        if not is_slice_index:
            raise AssertionError(
                "align_to_this_slice must be a slice index of `s`" +
                " or a 2D reference slice")
        skip_aligning_this_slice = align_to_this_slice
        align_to_this_slice = s[align_to_this_slice, :, :]
    assert align_to_this_slice.shape == s.shape[-2:]
    assert refinement in ('integer', 'spike_interpolation')
    assert register_in_place in (True, False)
    if fourier_cutoff_radius is not None:
        assert (0 < fourier_cutoff_radius <= 0.5)
    assert debug in (True, False)
    if debug and np_tif is None:
        raise UserWarning("Failed to import np_tif; no debug mode.")
    ## We'll base our registration on the phase of the low spatial
    ## frequencies of the cross-power spectrum. We'll need the complex
    ## conjugate of the Fourier transform of the periodic version of the
    ## reference slice, and a mask in the Fourier domain to pick out the
    ## low spatial frequencies:
    ref_slice_rfft = np.fft.rfftn(align_to_this_slice)
    if fourier_cutoff_radius is None:  # Attempt to estimate a sensible FT radius
        fourier_cutoff_radius = estimate_fourier_cutoff_radius(
            ref_slice_rfft,
            debug=debug,
            input_is_rfftd=True,
            shape=s.shape[1:])
    ref_slice_ft_conj = np.conj(ref_slice_rfft -
                                _smooth_rfft2(align_to_this_slice))
    del ref_slice_rfft
    k_ud = np.fft.fftfreq(s.shape[1]).reshape(ref_slice_ft_conj.shape[0], 1)
    k_lr = np.fft.rfftfreq(s.shape[2]).reshape(1, ref_slice_ft_conj.shape[1])
    fourier_mask = (k_ud**2 + k_lr**2) < (fourier_cutoff_radius)**2
    ## Now we'll loop over each slice of the stack, calculate our
    ## registration shifts, and optionally apply the shifts to the
    ## original stack.
    registration_shifts = []
    if debug:
        ## Save some intermediate data to help with debugging
        stack_ft = np.zeros((s.shape[0], ) + ref_slice_ft_conj.shape,
                            dtype=np.complex128)
        ## The periodic component is real-valued and can be negative, so
        ## store it as float rather than inheriting a (possibly unsigned
        ## integer) dtype from `s`, which would truncate or wrap it:
        periodic_stack = np.zeros(s.shape, dtype=np.float64)
        periodic_stack_ft = np.zeros((s.shape[0], ) + ref_slice_ft_conj.shape,
                                     dtype=np.complex128)
        periodic_stack_ft_vs_ref = np.zeros_like(periodic_stack_ft)
        cross_power_spectra = np.zeros_like(periodic_stack_ft)
        spikes = np.zeros(s.shape, dtype=np.float64)
    for which_slice in range(s.shape[0]):
        if which_slice == skip_aligning_this_slice and not debug:
            ## The reference slice trivially needs no shift.
            registration_shifts.append(np.zeros(2))
            continue
        if debug: print("Calculating registration for slice", which_slice)
        ## Compute the cross-power spectrum of the periodic component of
        ## our slice with the periodic component of the reference slice,
        ## and mask out the high spatial frequencies.
        current_slice_ft = np.fft.rfftn(s[which_slice, :, :])
        current_slice_ft_periodic = (current_slice_ft -
                                     _smooth_rfft2(s[which_slice, :, :]))
        cross_power_spectrum = current_slice_ft_periodic * ref_slice_ft_conj
        norm = np.abs(cross_power_spectrum)
        norm[norm == 0] = 1  # Avoid 0/0 where both spectra vanish
        cross_power_spectrum = (fourier_mask * cross_power_spectrum / norm)
        ## Inverse transform to get a 'spike' in real space. The
        ## location of this spike gives the desired registration shift.
        ## Start by locating the spike to the nearest integer:
        spike = np.fft.irfftn(cross_power_spectrum, s=s.shape[1:])
        loc = np.array(np.unravel_index(np.argmax(spike), spike.shape))
        if refinement == 'spike_interpolation':
            ## Use (very simple) three-point polynomial interpolation to
            ## refine the location of the peak of the spike:
            neighbors = np.array([-1, 0, 1])
            ud_vals = spike[(loc[0] + neighbors) % spike.shape[0], loc[1]]
            lr_vals = spike[loc[0], (loc[1] + neighbors) % spike.shape[1]]
            lr_fit = np.poly1d(np.polyfit(neighbors, lr_vals, deg=2))
            ud_fit = np.poly1d(np.polyfit(neighbors, ud_vals, deg=2))
            lr_max_shift = -lr_fit[1] / (2 * lr_fit[2])
            ud_max_shift = -ud_fit[1] / (2 * ud_fit[2])
            loc = loc + (ud_max_shift, lr_max_shift)
        ## Convert our shift into a signed number near zero:
        loc = ((np.array(spike.shape) // 2 + loc) % np.array(spike.shape) -
               np.array(spike.shape) // 2)
        registration_shifts.append(loc)
        if register_in_place:
            ## Modify the input stack in-place so it's registered.
            apply_registration_shifts(
                s[which_slice:which_slice + 1, :, :], [loc],
                registration_type=('nearest_integer' if refinement == 'integer'
                                   else 'fourier_interpolation'),
                s_rfft=np.expand_dims(current_slice_ft, axis=0))
        if debug:
            ## Save some intermediate data to help with debugging
            stack_ft[which_slice, :, :] = (np.fft.fftshift(current_slice_ft,
                                                           axes=0))
            periodic_stack[which_slice, :, :] = np.fft.irfftn(
                current_slice_ft_periodic, s=s.shape[1:]).real
            periodic_stack_ft[which_slice, :, :] = (np.fft.fftshift(
                current_slice_ft_periodic, axes=0))
            periodic_stack_ft_vs_ref[which_slice, :, :] = (np.fft.fftshift(
                current_slice_ft_periodic * ref_slice_ft_conj, axes=0))
            cross_power_spectra[which_slice, :, :] = (np.fft.fftshift(
                cross_power_spectrum, axes=0))
            spikes[which_slice, :, :] = np.fft.fftshift(spike)
    if debug:
        np_tif.array_to_tif(periodic_stack, 'DEBUG_1_periodic_stack.tif')
        np_tif.array_to_tif(np.log(1 + np.abs(stack_ft)),
                            'DEBUG_2_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(np.angle(stack_ft), 'DEBUG_3_stack_FT_phases.tif')
        np_tif.array_to_tif(np.log(1 + np.abs(periodic_stack_ft)),
                            'DEBUG_4_periodic_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(np.angle(periodic_stack_ft),
                            'DEBUG_5_periodic_stack_FT_phases.tif')
        np_tif.array_to_tif(np.angle(periodic_stack_ft_vs_ref),
                            'DEBUG_6_periodic_stack_FT_phase_vs_ref.tif')
        np_tif.array_to_tif(np.angle(cross_power_spectra),
                            'DEBUG_7_cross_power_spectral_phases.tif')
        np_tif.array_to_tif(spikes, 'DEBUG_8_spikes.tif')
        if register_in_place:
            np_tif.array_to_tif(s, 'DEBUG_9_registered_stack.tif')
    return np.array(registration_shifts)
Exemple #15
0
def generate_psfs(
    shape,  #Desired pixel dimensions of the psfs
    excitation_brightness,  #Peak brightness in saturation units
    depletion_brightness,  #Peak brightness in saturation units
    blur_sigma,
    psf_type='point',
    output_dir=None,
    verbose=True,
):
    """
    A utility function used by psf_report().

    Simulate excitation, depletion and STED PSFs for a focused point
    (psf_type='point') or a line focus (psf_type='line').

    Parameters:
      shape: Pixel dimensions of the 3D psf arrays. The source is placed
        in slice 0 of axis 0; the line-rescan simulation additionally
        assumes shape[0] == 1.
      excitation_brightness, depletion_brightness: Peak fluence per pulse
        of the excitation / depletion patterns, in saturation units.
      blur_sigma: Gaussian blur width in pixels. For 'point' it is
        applied along every axis; for 'line' only along axis 2 (across
        the line). For 'line' it is also used as the emission psf width.
      psf_type: 'point' or 'line'.
      output_dir: If not None, every intermediate psf is saved there as a
        TIF (the directory is created, including parents, if needed).
      verbose: Print progress of the line-rescan simulation.

    Returns a dict with keys 'excitation', 'depletion',
    'excitation_fraction', 'depletion_fraction', 'sted' and 'descan_sted';
    for psf_type='line' there is also 'rescan_sted'.

    Raises ValueError if psf_type is not 'point' or 'line'.
    """
    if psf_type not in ('point', 'line'):
        raise ValueError(
            "psf_type must be 'point' or 'line', not %r" % (psf_type,))
    # A point source is a single bright pixel, blurred in every direction.
    # A line source is a bright column, blurred only across the line.
    source = np.zeros(shape)
    if psf_type == 'point':
        source[0, shape[1] // 2, shape[2] // 2] = 1
        sigma = blur_sigma
    else:
        source[0, :, shape[2] // 2] = 1
        sigma = (0, 0, blur_sigma)
    blurred_source = gaussian_filter(source, sigma=sigma)
    # Gaussian excitation pattern. Pixel values encode fluence per pulse,
    # in saturation units.
    excitation = blurred_source * (excitation_brightness /
                                   blurred_source.max())
    # Difference-of-gaussian depletion pattern (zero at the center). Pixel
    # values encode fluence per pulse, in saturation units.
    depletion_inner = blurred_source
    depletion_outer = gaussian_filter(depletion_inner, sigma=sigma)
    depletion = ((depletion_outer / depletion_outer.max()) -
                 (depletion_inner / depletion_inner.max()))
    depletion *= depletion_brightness / depletion.max()
    # "Saturated" excitation/depletion patterns. Pixel values encode
    # probability per pulse that a ground-state molecule will become
    # excited (excitation) or an excited molecule will remain excited
    # (depletion).
    half_on_dose = 1
    half_off_dose = 1
    excitation_fraction = 1 - 2**(-excitation / half_on_dose)
    depletion_fraction = 2**(-depletion / half_off_dose)
    # Post-depletion excitation pattern: probability per pulse that a
    # molecule will become excited, but not be depleted.
    sted = excitation_fraction * depletion_fraction
    # The "system" PSF can depend on both excitation and emission. For
    # descanned point-STED it is the (STED-shrunk) excitation PSF. For
    # rescanned line-STED it also involves the emission PSF.
    if psf_type == 'point':
        descanned_sted = sted  # Simple rename
    else:
        # Assume emission PSF same as excitation PSF
        (emission_psf, rescanned_signal_cumu, rescanned_sted,
         descanned_sted) = _simulate_line_rescan(
             shape, sted, blur_sigma, verbose)
    if output_dir is not None:
        # makedirs also creates missing parents and tolerates an existing dir
        os.makedirs(output_dir, exist_ok=True)
        to_save = [
            (excitation, 'excitation_psf_' + psf_type + '.tif'),
            (depletion, 'depletion_psf_' + psf_type + '.tif'),
            (excitation_fraction,
             'excitation_fraction_psf_' + psf_type + '.tif'),
            (depletion_fraction,
             'depletion_fraction_psf_' + psf_type + '.tif'),
            (sted, 'sted_psf_' + psf_type + '.tif'),
        ]
        if psf_type == 'line':
            to_save.extend([
                (emission_psf, 'emission_psf.tif'),
                (rescanned_signal_cumu, 'sted_psf_line_rescan_unscaled.tif'),
                (rescanned_sted, 'sted_psf_line_rescan.tif'),
                (descanned_sted, 'sted_psf_line_descan.tif'),
            ])
        for array, filename in to_save:
            np_tif.array_to_tif(array, os.path.join(output_dir, filename))
    psfs = {
        'excitation': excitation,
        'depletion': depletion,
        'excitation_fraction': excitation_fraction,
        'depletion_fraction': depletion_fraction,
        'sted': sted,
        'descan_sted': descanned_sted,
    }
    if psf_type == 'line':
        psfs['rescan_sted'] = rescanned_sted
    return psfs


def _simulate_line_rescan(shape, sted_psf_line, emission_sigma, verbose):
    """Simulate descanned and rescanned line-STED imaging of a point object.

    Helper for generate_psfs(). Returns the tuple (emission_psf,
    rescanned_signal_cumu, rescanned_line_sted_psf, descanned_line_sted_psf).
    Assumes shape[0] == 1 (the rescanned image is binned with a leading
    axis of length 1).
    """
    line_sted_sigma, _ = get_width(
        sted_psf_line[0, sted_psf_line.shape[1] // 2, :])
    line_rescan_ratio = (emission_sigma / line_sted_sigma)**2 + 1
    if verbose:
        print(" Ideal line rescan ratio: %0.5f" % (line_rescan_ratio))
    line_rescan_ratio = int(np.round(line_rescan_ratio))
    if verbose:
        print(" Nearest integer:", line_rescan_ratio)
    point_obj = np.zeros(shape)
    n_z, n_y, n_x = point_obj.shape
    point_obj[0, n_y // 2, n_x // 2] = 1
    emission_psf = gaussian_filter(point_obj, sigma=emission_sigma)
    rescanned_signal_inst = np.zeros((n_z, n_y, line_rescan_ratio * n_x))
    rescanned_signal_cumu = rescanned_signal_inst.copy()
    descanned_signal_cumu = np.zeros(shape)
    if verbose:
        print(" Calculating rescan psf...", end='')
    # I could use an analytical shortcut to calculate the rescan PSF
    # (see http://dx.doi.org/10.1364/BOE.4.002644 for details), but
    # for the sake of clarity, I explicitly simulate the rescan
    # imaging process:
    for scan_position in range(n_x):
        # . Scan the excitation
        scanned_excitation = np.roll(sted_psf_line,
                                     scan_position - n_x // 2,
                                     axis=2)
        # . Multiply the object by excitation to calculate the "glow":
        glow = point_obj * scanned_excitation
        # . Blur the glow by the emission PSF to calculate the image
        #   on the detector:
        blurred_glow = gaussian_filter(glow, sigma=emission_sigma)
        # . Calculate the contribution to the descanned image (the
        #   kind measured by Curdt or Schubert
        #   http://www.ub.uni-heidelberg.de/archiv/14362
        #   http://www.ub.uni-heidelberg.de/archiv/15986
        descanned_signal = np.roll(blurred_glow,
                                   n_x // 2 - scan_position,
                                   axis=2)
        descanned_signal_cumu[:, :, scan_position] += descanned_signal.sum(
            axis=2)
        # . Roll the descanned image to the rescan position, to
        #   produce the "instantaneous" rescanned image:
        rescanned_signal_inst.fill(0)
        rescanned_signal_inst[0, :, :n_x] = descanned_signal
        rescanned_signal_inst = np.roll(rescanned_signal_inst,
                                        scan_position * line_rescan_ratio -
                                        n_x // 2,
                                        axis=2)
        # . Add the "instantaneous" image to the "cumulative" image.
        rescanned_signal_cumu += rescanned_signal_inst
    if verbose:
        print(" ...done.")
    # . Bin the rescanned psf back to the same dimensions as the object:
    #   roll so the center bin is centered on the image, then sum groups
    #   of `line_rescan_ratio` neighboring pixels along axis 2.
    rescanned_line_sted_psf = np.roll(
        rescanned_signal_cumu, line_rescan_ratio // 2, axis=2).reshape(
            1, n_y, n_x, line_rescan_ratio).sum(axis=3)
    return (emission_psf, rescanned_signal_cumu, rescanned_line_sted_psf,
            descanned_signal_cumu)
else:
    print('Loading original file...', end='', sep='')
    data = np_tif.tif_to_array(input_filename)
    print('done')
    data = data.reshape((num_tps, ) + data.shape[-2:])
    print('tif shape (t, y, x) =', data.shape)
    print('Cropping...', end='', sep='')
    if left_crop or right_crop > 0:
        data = data[:, :, left_crop:-right_crop]
    if top_crop or bottom_crop > 0:
        data = data[:, top_crop:-bottom_crop, :]
    print('done')
    print("Saving result...", end='', sep='')
    np_tif.array_to_tif(data.reshape(num_tps, data.shape[-2], data.shape[-1]),
                        cropped_filename,
                        slices=1,
                        channels=1,
                        frames=num_tps)
    print('done')
    print('tif shape (t, y, x) =', data.shape)

## Video display parameters
current_frame = -1
# Layout of the panels (presumably fractions of the figure size -- confirm)
xmargin, ymargin = 0.01, 0.025
space = 0.175
img_size = 0.5
max_intensity = 12500
wlw = 2  # half of the white separator line width
# Time-point window: drop the uneventful beginning and end of the movie
start_tp, stop_tp = 0, 100
Exemple #17
0
# Summarize the 6-axis hyperstack `data_rep` into spatially averaged
# per-frame signal/background values and one single-shot representative
# image, and save them as TIFs (relative paths). NOTE(review): `data_rep`
# is not defined in this snippet; axis 0 is averaged below as the
# repetition axis, and the last two axes are presumably image rows and
# columns -- confirm against where it is built.

# crop data (rectangle around bright fluorescent lobe)
data_rep_cropped = data_rep[:, :, :, :, 49:97, 147:195]
# Average over the two cropped image axes -> one value per frame (4 axes)
data_rep_signal = data_rep_cropped.mean(axis=5).mean(axis=4)

# get data background
# (mean of a 10 x 10 pixel box that does not overlap the lobe crop above)
data_rep_bg = data_rep[:, :, :, :, 20:30, 20:30]
data_rep_bg = data_rep_bg.mean(axis=5).mean(axis=4)

# compute repetition average of data after image registration
# NOTE(review): data_avg is only used for data_avg_tif_shape, which is
# not used anywhere below.
data_avg = data_rep.mean(axis=0)
##representative_image_avg = (data_avg[0,-1,-1,16:128,108:238] -
##                            data_rep_bg[:,0,-1,-1].mean(axis=0))
# First repetition, last index along axes 1-3, cropped, minus the matching
# background value.
rep_image_single_shot = (data_rep[0, -1, -1, -1, 16:128, 108:238] -
                         data_rep_bg[0, -1, -1, -1])

data_avg_tif_shape = (data_avg.shape[0] * data_avg.shape[1] *
                      data_avg.shape[2], data_avg.shape[3], data_avg.shape[4])

# Merge the first two axes so the 4-axis point data can be written as a
# 3D TIF stack.
point_data_tif_shape = (data_rep_signal.shape[0] * data_rep_signal.shape[1],
                        data_rep_signal.shape[2], data_rep_signal.shape[3])

print("Saving...")
np_tif.array_to_tif(data_rep_signal.reshape(point_data_tif_shape),
                    'data_point_signal.tif')
np_tif.array_to_tif(data_rep_bg.reshape(point_data_tif_shape),
                    'data_point_bg.tif')
##np_tif.array_to_tif(
##    representative_image_avg,'representative_image_avg.tif')
np_tif.array_to_tif(rep_image_single_shot, 'rep_image_single_shot.tif')
print("... done.")
def main():
    """Register, brightness-correct and summarize a repeated power/delay scan.

    Each repetition `rep_num` (0 .. num_reps - 1) is stored as two raw TIF
    stacks, read by relative path:
      'STE_darkfield_power_delay_scan_<rep_num>.tif'
      'STE_darkfield_power_delay_scan_<rep_num>_green_blocked.tif'
    Each raw stack has a full red and green power scan with red varying
    slowly, green varying more quickly and green/red pulse delay varying
    the quickest (5 delays, middle delay is 0 delay).

    Steps:
      1. Average every stack over all of its frames, and register those
         per-repetition averages to the repetition-0 (green on) average.
      2. Re-read the raw stacks, apply the registration shifts, and
         rescale every frame to compensate for red beam brightness
         fluctuations.
      3. Average a rectangle around the bright lobe to get one signal
         value per (repetition, red power, green power, delay), and
         accumulate representative STE and darkfield images at maximum
         red and green power.
      4. Save intermediate and final results as TIF files.

    Returns None.
    """
    num_reps = 200  # number power scans taken
    num_red_powers = 7
    num_green_powers = 13
    num_delays = 5
    image_h = 128
    image_w = 380
    less_rows = 3  # top/bottom 3 rows may contain leakage from outside pixels
    top = less_rows
    bot = image_h - less_rows
    cropped_h = image_h - less_rows * 2
    num_frames = num_red_powers * num_green_powers * num_delays
    # Frame order within a raw stack, as a (red, green, delay) hyperstack
    scan_shape = (num_red_powers, num_green_powers, num_delays)

    def scan_filenames(rep_num):
        """Return the (green on, green blocked) filenames of one repetition."""
        base = 'STE_darkfield_power_delay_scan_' + str(rep_num)
        return base + '.tif', base + '_green_blocked.tif'

    def load_scan(filename):
        """Load one raw stack as float64, dropping the leaky edge rows."""
        return np_tif.tif_to_array(filename).astype(np.float64)[:, top:bot, :]

    def red_brightness_per_red_power(scan):
        """get_bg_level() of `scan` averaged over green powers and delays."""
        return get_bg_level(
            scan.reshape(scan_shape + (cropped_h, image_w)).mean(
                axis=1).mean(axis=1))

    def rescale_to_red_brightness(scan, red_avg_brightness):
        """Rescale every frame of `scan` by red_avg_brightness / its own
        get_bg_level(), to compensate for red beam brightness fluctuations.
        """
        # Use the stack's own image shape (the original code used the
        # green-on stack's shape for the green-blocked stack too).
        local_laser_brightness = get_bg_level(
            scan.reshape(scan_shape + scan.shape[-2:]))
        local_calibration_factor = red_avg_brightness / local_laser_brightness
        local_calibration_factor = local_calibration_factor.reshape(
            num_frames, 1, 1)
        return scan * local_calibration_factor

    # assume no sample motion during a single power scan
    # allocate hyperstack to carry power/delay-averaged images for registration
    data_rep = np.zeros((num_reps, cropped_h, image_w), dtype=np.float64)
    data_rep_bg = np.zeros((num_reps, cropped_h, image_w), dtype=np.float64)

    # allocate array to carry a number corresponding to the average red
    # beam brightness for each red power
    red_avg_brightness = np.zeros((num_red_powers))

    # populate hyperstack from data
    for rep_num in range(num_reps):
        filename, filename_bg = scan_filenames(rep_num)
        print("Loading", filename)
        imported_power_scan = load_scan(filename)
        red_avg_brightness += red_brightness_per_red_power(
            imported_power_scan) / (2 * num_reps)
        data_rep[rep_num, :, :] = imported_power_scan.mean(axis=0)
        print("Loading", filename_bg)
        imported_power_scan_bg = load_scan(filename_bg)
        red_avg_brightness += red_brightness_per_red_power(
            imported_power_scan_bg) / (2 * num_reps)
        data_rep_bg[rep_num, :, :] = imported_power_scan_bg.mean(axis=0)

    # reshape red_avg_brightness to add dimensions for broadcasting against
    # a (num_red_powers, num_green_powers, num_delays) brightness array
    red_avg_brightness = red_avg_brightness.reshape(num_red_powers, 1, 1)

    # pick image/slice for all stacks to align to
    representative_rep_num = 0
    align_slice = data_rep[representative_rep_num, :, :]

    # save pre-registered average data (all powers for each rep)
    np_tif.array_to_tif(data_rep, 'dataset_not_registered_power_avg.tif')
    np_tif.array_to_tif(data_rep_bg,
                        'dataset_green_blocked_not_registered_power_avg.tif')

    # compute registration shifts
    print("Computing registration shifts...")
    shifts = stack_registration(data_rep,
                                align_to_this_slice=align_slice,
                                refinement='integer',
                                register_in_place=True,
                                background_subtraction='edge_mean')
    print("Computing registration shifts (no green) ...")
    shifts_bg = stack_registration(data_rep_bg,
                                   align_to_this_slice=align_slice,
                                   refinement='integer',
                                   register_in_place=True,
                                   background_subtraction='edge_mean')

    # save registered average data (all powers for each rep) and shifts
    np_tif.array_to_tif(data_rep, 'dataset_registered_power_avg.tif')
    np_tif.array_to_tif(data_rep_bg,
                        'dataset_green_blocked_registered_power_avg.tif')
    np_tif.array_to_tif(shifts, 'shifts.tif')
    np_tif.array_to_tif(shifts_bg, 'shifts_bg.tif')

    # now apply shifts to raw data and compute space-averaged signal
    # and representative images

    # define box around main lobe for computing space-averaged signal
    rect_top = 44
    rect_bot = 102
    rect_left = 172
    rect_right = 228

    # initialize hyperstacks for signal (with/without green light)
    print('Applying shifts to raw data...')
    signal = np.zeros((num_reps,) + scan_shape, dtype=np.float64)
    signal_bg = np.zeros((num_reps,) + scan_shape, dtype=np.float64)

    # get representative image cropping coordinates
    rep_top = 22
    rep_bot = 122
    rep_left = 136
    rep_right = 262

    # initialize representative images (with/without green light)
    rep_shape = (rep_bot - rep_top, rep_right - rep_left)
    darkfield_image = np.zeros(rep_shape, dtype=np.float64)
    STE_image = np.zeros(rep_shape, dtype=np.float64)
    darkfield_image_bg = np.zeros(rep_shape, dtype=np.float64)
    STE_image_bg = np.zeros(rep_shape, dtype=np.float64)

    # Representative images use max red power and max green power; STE at
    # zero delay (middle delay), darkfield averaged over the two extreme
    # delays (+/- 2.5 us per the original notes).
    image_red_power = num_red_powers - 1
    image_green_power = num_green_powers - 1
    zero_delay = num_delays // 2
    min_delay = 0
    max_delay = num_delays - 1

    # finally apply shifts and compute output data
    for rep_num in range(num_reps):
        filename, filename_bg = scan_filenames(rep_num)
        data = load_scan(filename)
        data_bg = load_scan(filename_bg)
        print(filename)
        print(filename_bg)
        # apply registration shifts (modifies data/data_bg in place)
        apply_registration_shifts(data,
                                  registration_shifts=[shifts[rep_num]] *
                                  data.shape[0],
                                  registration_type='nearest_integer',
                                  edges='sloppy')
        apply_registration_shifts(data_bg,
                                  registration_shifts=[shifts_bg[rep_num]] *
                                  data_bg.shape[0],
                                  registration_type='nearest_integer',
                                  edges='sloppy')
        # re-scale images to compensate for red beam brightness fluctuations
        data = rescale_to_red_brightness(data, red_avg_brightness)
        data_bg = rescale_to_red_brightness(data_bg, red_avg_brightness)
        # draw rectangle around bright lobe and spatially average signal
        data_space_avg = data[:, rect_top:rect_bot,
                              rect_left:rect_right].mean(axis=2).mean(axis=1)
        data_bg_space_avg = data_bg[:, rect_top:rect_bot,
                                    rect_left:rect_right].mean(axis=2).mean(
                                        axis=1)
        # reshape 1D signal and place in output file
        signal[rep_num, :, :, :] = data_space_avg.reshape(scan_shape)
        signal_bg[rep_num, :, :, :] = data_bg_space_avg.reshape(scan_shape)
        # capture average images for max red/green power; each of these is
        # indexed by delay
        max_power = data.reshape(scan_shape + data.shape[-2:])[
            image_red_power, image_green_power]
        max_power_bg = data_bg.reshape(scan_shape + data_bg.shape[-2:])[
            image_red_power, image_green_power]
        STE_image += max_power[
            zero_delay, rep_top:rep_bot, rep_left:rep_right] / num_reps
        # darkfield: mean of the two maximum absolute red/green delays
        darkfield_image += max_power[
            max_delay, rep_top:rep_bot, rep_left:rep_right] / num_reps / 2
        darkfield_image += max_power[
            min_delay, rep_top:rep_bot, rep_left:rep_right] / num_reps / 2
        STE_image_bg += max_power_bg[
            zero_delay, rep_top:rep_bot, rep_left:rep_right] / num_reps
        darkfield_image_bg += max_power_bg[
            max_delay, rep_top:rep_bot, rep_left:rep_right] / num_reps / 2
        darkfield_image_bg += max_power_bg[
            min_delay, rep_top:rep_bot, rep_left:rep_right] / num_reps / 2

    print('Done applying shifts')

    signal_tif_shape = (signal.shape[0] * signal.shape[1], signal.shape[2],
                        signal.shape[3])

    print("Saving...")
    np_tif.array_to_tif(signal.reshape(signal_tif_shape),
                        'signal_all_scaled.tif')
    np_tif.array_to_tif(signal_bg.reshape(signal_tif_shape),
                        'signal_green_blocked_all_scaled.tif')
    np_tif.array_to_tif(darkfield_image, 'darkfield_image_avg.tif')
    np_tif.array_to_tif(darkfield_image_bg, 'darkfield_image_bg_avg.tif')
    np_tif.array_to_tif(STE_image, 'STE_image_avg.tif')
    np_tif.array_to_tif(STE_image_bg, 'STE_image_bg_avg.tif')
    print("... done.")

    return None
Exemple #20
0
    ]

# For each green power, load the single-shot stacks for every
# (angle, red power) combination and save them as one TIF.
for gr_pow in green_powers:
    data_list = []
    for ang in angles:
        for rd_pow in red_powers:
            filename = (
                'STE_' +
                'phase_angle_' + ang +
                '_green' + gr_pow +
                '_red' + rd_pow +
                '.tif')
            # Explicit checks rather than `assert`, so they still run
            # under `python -O`.
            if not os.path.exists(filename):
                raise FileNotFoundError(filename)
            print("Loading", filename)
            data = np_tif.tif_to_array(filename).astype(np.float64)
            if data.shape != (3, 128, 380):
                raise ValueError("%s has shape %s; expected (3, 128, 380)" %
                                 (filename, data.shape))
            data_list.append(data)
    print("Done loading.")
    # Stack along a new leading axis: one entry per (angle, red power)
    data = np.stack(data_list)
    print('data shape is',data.shape)
    tif_shape = (data.shape[0] * data.shape[1], data.shape[2], data.shape[3])

    print("Saving...")
    np_tif.array_to_tif(data.reshape(tif_shape),
                        ('dataset_green'+gr_pow+'_single_shot.tif'))
    print("Done saving.")
def stack_registration(
    s,
    align_to_this_slice=0,
    refinement='spike_interpolation',
    register_in_place=True,
    fourier_cutoff_radius=None,
    debug=False):
    """Calculate shifts which would register the slices of a
    three-dimensional stack `s`, and optionally register the stack in-place.

    Axis 0 is the "z-axis", axis 1 is the "up-down" (Y) axis, and axis 2
    is the "left-right" (X) axis. For each XY slice, we calculate the
    shift in the XY plane which would line that slice up with the slice
    specified by `align_to_this_slice`. If `align_to_this_slice` is a
    number, it indicates which slice of `s` to use as the reference
    slice. If `align_to_this_slice` is a numpy array, it is used as the
    reference slice, and must be the same shape as a 2D slice of `s`.

    `refinement` is one of `integer`, `spike_interpolation`, or
    `phase_fitting`, in order of increasing precision/slowness. I don't
    yet have any evidence that my implementation of phase fitting gives
    any improvement over (faster, simpler) spike interpolation, so
    caveat emptor.

    `register_in_place`: If `True`, modify the input stack `s` by
    shifting its slices to line up with the reference slice.

    `fourier_cutoff_radius`: Ignore the Fourier phases of spatial
    frequencies higher than this cutoff, since they're probably lousy
    due to aliasing and noise anyway. If `None`, attempt to estimate a
    reasonable cutoff. Must satisfy 0 < radius <= 0.5, in cycles per
    pixel (the units of `np.fft.fftfreq`).

    `debug`: If `True`, print progress and save intermediate arrays as
    'DEBUG_*.tif' files in the current working directory (needs np_tif).

    Returns a numpy array with one (up-down, left-right) shift per
    slice, i.e. of shape (s.shape[0], 2).

    Raises AssertionError for invalid arguments, and UserWarning if
    `refinement='phase_fitting'` is requested without scipy's `minimize`,
    or `debug=True` is requested without np_tif.
    """
    assert len(s.shape) == 3
    ## `align_to_this_slice` is either a slice index or a 2D reference
    ## image. Testing a multi-element array for membership in a range
    ## raises ValueError (ambiguous truth value), which routes arrays to
    ## the `except` branch.
    try:
        assert align_to_this_slice in range(s.shape[0])
        align_to_this_slice = s[align_to_this_slice, :, :]
    except ValueError:
        align_to_this_slice = np.squeeze(align_to_this_slice)
    assert align_to_this_slice.shape == s.shape[-2:]
    assert refinement in ('integer', 'spike_interpolation', 'phase_fitting')
    if refinement == 'phase_fitting' and minimize is None:
        raise UserWarning("Failed to import scipy minimize; no phase fitting.")
    assert register_in_place in (True, False)
    assert debug in (True, False)
    ## Fail fast: check for np_tif before any (possibly slow) work that
    ## also receives the `debug` flag.
    if debug and np_tif is None:
        raise UserWarning("Failed to import np_tif; no debug mode.")
    if fourier_cutoff_radius is None:
        fourier_cutoff_radius = estimate_fourier_cutoff_radius(s, debug)
    assert (0 < fourier_cutoff_radius <= 0.5)
    ## Multiply each slice of the stack by an XY mask that goes to zero
    ## at the edges, to prevent periodic boundary artifacts when we
    ## Fourier transform.
    mask_ud = np.sin(np.linspace(0, np.pi, s.shape[1])).reshape(s.shape[1], 1)
    mask_lr = np.sin(np.linspace(0, np.pi, s.shape[2])).reshape(1, s.shape[2])
    masked_reference_slice = align_to_this_slice * mask_ud * mask_lr
    ## We'll base our registration on the phase of the low spatial
    ## frequencies of the cross-power spectrum. We'll need the complex
    ## conjugate of the Fourier transform of the masked reference slice,
    ## and a mask in the Fourier domain to pick out the low spatial
    ## frequencies:
    ref_slice_ft_conj = np.conj(np.fft.rfftn(masked_reference_slice))
    k_ud = np.fft.fftfreq(s.shape[1]).reshape(ref_slice_ft_conj.shape[0], 1)
    k_lr = np.fft.rfftfreq(s.shape[2]).reshape(1, ref_slice_ft_conj.shape[1])
    fourier_mask = (k_ud**2 + k_lr**2) < (fourier_cutoff_radius)**2
    ## Now we'll loop over each slice of the stack, calculate our
    ## registration shifts, and optionally apply the shifts to the
    ## original stack.
    registration_shifts = []
    if debug:
        ## Save some intermediate data to help with debugging
        masked_stack = np.zeros_like(s)
        masked_stack_ft = np.zeros(
            (s.shape[0],) + ref_slice_ft_conj.shape, dtype=np.complex128)
        masked_stack_ft_vs_ref = np.zeros_like(masked_stack_ft)
        cross_power_spectra = np.zeros_like(masked_stack_ft)
        spikes = np.zeros(s.shape, dtype=np.float64)
    for which_slice in range(s.shape[0]):
        if debug: print("Calculating registration for slice", which_slice)
        ## Compute the cross-power spectrum of our slice, and mask out
        ## the high spatial frequencies.
        current_slice = s[which_slice, :, :] * mask_ud * mask_lr
        current_slice_ft = np.fft.rfftn(current_slice)
        raw_cross_power = current_slice_ft * ref_slice_ft_conj
        ## Normalize to unit magnitude so only the phase remains. An
        ## exactly-zero entry would give 0/0 = NaN (masking does not
        ## help: NaN * 0 is still NaN), and one NaN spreads to every
        ## pixel of the inverse transform. Divide those entries by 1
        ## instead, which leaves them at zero.
        magnitude = np.abs(raw_cross_power)
        magnitude[magnitude == 0] = 1
        cross_power_spectrum = fourier_mask * raw_cross_power / magnitude
        ## Inverse transform to get a 'spike' in real space. The
        ## location of this spike gives the desired registration shift.
        ## Start by locating the spike to the nearest integer:
        spike = np.fft.irfftn(cross_power_spectrum, s=current_slice.shape)
        loc = np.array(np.unravel_index(np.argmax(spike), spike.shape))
        if refinement in ('spike_interpolation', 'phase_fitting'):
            ## Use (very simple) three-point polynomial interpolation to
            ## refine the location of the peak of the spike:
            neighbors = np.array([-1, 0, 1])
            ud_vals = spike[(loc[0] + neighbors) %spike.shape[0], loc[1]]
            lr_vals = spike[loc[0], (loc[1] + neighbors) %spike.shape[1]]
            lr_fit = np.poly1d(np.polyfit(neighbors, lr_vals, deg=2))
            ud_fit = np.poly1d(np.polyfit(neighbors, ud_vals, deg=2))
            lr_max_shift = -lr_fit[1] / (2 * lr_fit[2])
            ud_max_shift = -ud_fit[1] / (2 * ud_fit[2])
            loc = loc + (ud_max_shift, lr_max_shift)
        ## Convert our shift into a signed number near zero:
        loc = ((np.array(spike.shape)//2 + loc) % np.array(spike.shape)
               -np.array(spike.shape)//2)
        if refinement == 'phase_fitting':
            if debug: print("Phase fitting slice", which_slice, "...")
            ## (Attempt to) further refine our registration shift by
            ## fitting Fourier phases. I'm not sure this does any good,
            ## perhaps my implementation is lousy?
            def minimize_me(loc, cross_power_spectrum):
                disagreement = np.abs(
                    expected_cross_power_spectrum(loc, k_ud, k_lr) -
                    cross_power_spectrum
                    )[fourier_mask].sum()
                if debug: print(" Shift:", loc, "Disagreement:", disagreement)
                return disagreement
            loc = minimize(minimize_me,
                           x0=loc,
                           args=(cross_power_spectrum,),
                           method='Nelder-Mead').x
        registration_shifts.append(loc)
        if debug:
            ## Save some intermediate data to help with debugging
            masked_stack[which_slice, :, :] = current_slice
            masked_stack_ft[which_slice, :, :] = (
                np.fft.fftshift(current_slice_ft, axes=0))
            masked_stack_ft_vs_ref[which_slice, :, :] = (
                np.fft.fftshift(raw_cross_power, axes=0))
            cross_power_spectra[which_slice, :, :] = (
                np.fft.fftshift(cross_power_spectrum, axes=0))
            spikes[which_slice, :, :] = np.fft.fftshift(spike)
    if register_in_place:
        ## Modify the input stack in-place so it's registered.
        if refinement == 'integer':
            registration_type = 'nearest_integer'
        else:
            registration_type = 'fourier_interpolation'
        ## NOTE(review): `registration_type` is computed but not passed
        ## on, so the shifts are applied with whatever method
        ## apply_registration_shifts uses by default. If that function
        ## accepts a `registration_type` argument, pass it here.
        apply_registration_shifts(s, registration_shifts)
    if debug:
        np_tif.array_to_tif(masked_stack, 'DEBUG_masked_stack.tif')
        np_tif.array_to_tif(np.log(np.abs(masked_stack_ft)),
                            'DEBUG_masked_stack_FT_log_magnitudes.tif')
        np_tif.array_to_tif(np.angle(masked_stack_ft),
                            'DEBUG_masked_stack_FT_phases.tif')
        np_tif.array_to_tif(np.angle(masked_stack_ft_vs_ref),
                            'DEBUG_masked_stack_FT_phase_vs_ref.tif')
        np_tif.array_to_tif(np.angle(cross_power_spectra),
                            'DEBUG_cross_power_spectral_phases.tif')
        np_tif.array_to_tif(spikes, 'DEBUG_spikes.tif')
        if register_in_place:
            np_tif.array_to_tif(s, 'DEBUG_registered_stack.tif')
    return np.array(registration_shifts)
        sr.apply_registration_shifts(data[which_t, :, :, :],
                                     registration_shifts)
    print('done')
    # Registraction will often leave black borders so it's best to crop
    # at the end to tidy up.
    print('Cropping...')
    left_crop = border
    right_crop = border
    top_crop = border
    bottom_crop = border
    data = data[:, :, top_crop:-bottom_crop, left_crop:-right_crop]
    print('done')
    print("Saving result...", end='')
    np_tif.array_to_tif(data.reshape(num_tps * num_slices, data.shape[-2],
                                     data.shape[-1]),
                        registered_filename,
                        slices=num_slices,
                        channels=1,
                        frames=num_tps)
    print('done')
    print('tif shape (t, z, y, z) =', data.shape)

## Choose parameters for video
# z-stack / playback settings (units not stated here -- confirm with the
# plotting code that consumes them).
num_z_stack, pause, z_slow_down_factor, z_scale = 4, 4, 1, 10
# Frame and time-point counters both start at -1.
current_frame = current_tp = -1
# Layout spacing values (presumably figure-relative fractions; TODO confirm).
xmargin, ymargin, space = 0.15, 0.725, 0.0275
        [-5, 0],
    ]
    bucket_size = (1, 4, 4)
    crop = 8
    print("Creating shifted stack from test object...")
    shifted_obj = np.zeros_like(obj)
    stack = np.zeros(
        (len(shifts) + 1, obj.shape[1] // bucket_size[1] - 2 * crop,
         obj.shape[2] // bucket_size[2] - 2 * crop))
    stack[0, :, :] = bucket(obj, bucket_size)[0, crop:-crop, crop:-crop]
    for s, (y, x) in enumerate(shifts):
        top = max(0, y)
        lef = max(0, x)
        bot = min(obj.shape[-2], obj.shape[-2] + y)
        rig = min(obj.shape[-1], obj.shape[-1] + x)
        shifted_obj.fill(0)
        shifted_obj[0, top:bot, lef:rig] = obj[0, top - y:bot - y,
                                               lef - x:rig - x]
        stack[s + 1, :, :] = bucket(shifted_obj, bucket_size)[0, crop:-crop,
                                                              crop:-crop]
    np_tif.array_to_tif(stack, 'DEBUG_stack.tif')
    print(" Done.")
    print("Registering test stack...")
    shifts = stack_registration(stack,
                                refinement='spike_interpolation',
                                debug=True)
    print(" Done.")
    for s in shifts:
        print(s)
    np_tif.array_to_tif(stack, 'DEBUG_stack_registered.tif')