Example #1
def CrossCorrelation(kernel,
                     mode='circular',
                     dtype=None,
                     backend=None,
                     label='X',
                     pad_value=0,
                     axis=None,
                     pad_convolution=True,
                     pad_fft=True,
                     invalid_support_value=1,
                     fft_backend=None):

    # Flip the kernel by conjugating its spectrum (the Fourier-domain
    # equivalent of a coordinate flip), so convolution becomes correlation
    kernel = iFt(conj(Ft(kernel)))

    # Call Convolution
    return Convolution(kernel,
                       mode=mode,
                       dtype=dtype,
                       backend=backend,
                       label=label,
                       pad_value=pad_value,
                       axis=axis,
                       pad_convolution=pad_convolution,
                       pad_fft=pad_fft,
                       invalid_support_value=invalid_support_value,
                       fft_backend=fft_backend)
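
The wrapper above relies on the identity that circular cross-correlation is convolution with a conjugate-flipped kernel. A minimal NumPy check of that identity (independent of the llops Ft/iFt/conj helpers assumed above):

    import numpy as np

    # Hedged sketch: verify that convolving with the conjugate-flipped kernel
    # reproduces circular cross-correlation (plain FFTs, no padding).
    x = np.random.rand(64)
    k = np.random.rand(64)

    # Direct circular cross-correlation: IFFT(conj(K) * X)
    xc_direct = np.fft.ifft(np.conj(np.fft.fft(k)) * np.fft.fft(x))

    # Same result via convolution with the flipped kernel, mirroring what
    # CrossCorrelation does before delegating to Convolution
    k_flipped = np.fft.ifft(np.conj(np.fft.fft(k)))
    xc_via_conv = np.fft.ifft(np.fft.fft(k_flipped) * np.fft.fft(x))

    assert np.allclose(xc_direct, xc_via_conv)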
Example #2
    def show_xc(x, figsize=(11, 3)):
        # F and _R come from the enclosing scope (not shown in this snippet):
        # F is a Fourier transform operator and _R.arguments[0] holds the
        # reference image
        xf_1 = F * _R.arguments[0]
        xf_2 = F * x

        # Compute the normalized cross-power spectrum (phase-only correlation)
        phasor = (conj(xf_1) * xf_2) / abs(conj(xf_1) * xf_2)

        import matplotlib.pyplot as plt
        import llops as yp

        plt.figure(figsize=figsize)
        plt.subplot(121)
        plt.imshow(yp.angle(phasor))
        plt.title('Phase of Frequency domain')
        plt.subplot(122)
        plt.imshow(yp.abs(yp.iFt(phasor)))
        plt.title('Amplitude of Object domain')
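
For a self-contained picture of what this phasor encodes, here is a small NumPy sketch (assuming a 2D image and a circular shift; names hypothetical): the normalized cross-power spectrum is a pure phase ramp, and its inverse FFT is a delta at the shift.

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.random((32, 32))
    b = np.roll(a, shift=(3, -5), axis=(0, 1))  # shifted copy of a

    A, B = np.fft.fft2(a), np.fft.fft2(b)
    cross_power = np.conj(A) * B
    phasor = cross_power / np.abs(cross_power)  # keep phase only

    # The inverse FFT of the phasor peaks at the (wrapped) shift
    peak = np.unravel_index(np.argmax(np.abs(np.fft.ifft2(phasor))), a.shape)
    print(peak)  # (3, 27), i.e. (3, -5) modulo the array shape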
Example #3
    def blur_vectors(self, dtype=None, backend=None, debug=False,
                     use_phase_ramp=False, corrections=None):
        """
        This function generates the object size, image size, and blur kernels from
        a libwallerlab dataset object.

            Args:
                dataset: An io.Dataset object
                dtype [np.float32]: Which datatype to use for kernel generation (All numpy datatypes supported)
            Returns:
                object_size: The object size this dataset can recover
                image_size: The computed image size of the dataset
                blur_kernel_list: A dictionary of blur kernels lists, one key per color channel.

        """
        # Assign dataset
        dataset = self

        # Get corrections from metadata (treat a missing dict as empty)
        if corrections is None:
            corrections = {}
        if len(corrections) == 0 and 'blur_vector' in self.metadata.calibration:
            corrections = dataset.metadata.calibration['blur_vector']

        # Get datatype and backends
        dtype = dtype if dtype is not None else yp.config.default_dtype
        backend = backend if backend is not None else yp.config.default_backend

        # Calculate effective pixel size if necessary
        if dataset.metadata.system.eff_pixel_size_um is None:
            dataset.metadata.system.eff_pixel_size_um = dataset.metadata.camera.pixel_size_um / \
                (dataset.metadata.objective.mag * dataset.metadata.system.mag)

        # Recover and store position and illumination list
        blur_vector_roi_list = []
        position_list, illumination_list = [], []
        frame_segment_map = []

        for frame_index in range(dataset.shape[0]):
            frame_state = dataset.frame_state_list[frame_index]

            # Store which segment this measurement uses
            frame_segment_map.append(frame_state['position']['common']['linear_segment_index'])

            # Extract list of illumination values for each time point
            if 'illumination' in frame_state:
                illumination_list_frame = []
                if type(frame_state['illumination']) is str:
                    illum_state_list = self._frame_state_list[0]['illumination']['states']
                else:
                    illum_state_list = frame_state['illumination']['states']
                for time_point in illum_state_list:
                    illumination_list_time_point = []
                    for illumination in time_point:
                        illumination_list_time_point.append(
                            {'index': illumination['index'], 'value': illumination['value']})
                    illumination_list_frame.append(illumination_list_time_point)

            else:
                raise ValueError('Frame %d does not contain illumination information' % frame_index)

            # Extract list of positions for each time point
            if 'position' in frame_state:
                position_list_frame = []
                for time_point in frame_state['position']['states']:
                    position_list_time_point = []
                    for position in time_point:
                        # Hoist the pixel size so every unit branch can use it
                        ps_um = dataset.metadata.system.eff_pixel_size_um
                        if 'units' in position['value']:
                            if position['value']['units'] == 'mm':
                                position_list_time_point.append(
                                    [1000 * position['value']['y'] / ps_um, 1000 * position['value']['x'] / ps_um])
                            elif position['value']['units'] == 'um':
                                position_list_time_point.append(
                                    [position['value']['y'] / ps_um, position['value']['x'] / ps_um])
                            elif position['value']['units'] == 'pixels':
                                position_list_time_point.append([position['value']['y'], position['value']['x']])
                            else:
                                raise ValueError('Invalid units %s for position in frame %d' %
                                                 (position['value']['units'], frame_index))
                        else:
                            # print('WARNING: Could not find position units in metadata, assuming mm')
                            ps_um = dataset.metadata.system.eff_pixel_size_um
                            position_list_time_point.append(
                                [1000 * position['value']['y'] / ps_um, 1000 * position['value']['x'] / ps_um])

                    position_list_frame.append(position_list_time_point[0])  # Assuming single time point for now.

                # Define positions and position indices used
                positions_used, position_indices_used = [], []
                for index, pos in enumerate(position_list_frame):
                    # Use this position if any color channel is illuminated
                    if any([illumination_list_frame[index][0]['value'][color] > 0 for color in illumination_list_frame[index][0]['value']]):
                        position_indices_used.append(index)
                        positions_used.append(pos)

                # Generate ROI for this blur vector
                from htdeblur.blurkernel import getPositionListBoundingBox
                blur_vector_roi = getPositionListBoundingBox(positions_used)

                # Append to list
                blur_vector_roi_list.append(blur_vector_roi)

                # Crop illumination list to values within the support used
                illumination_list.append([illumination_list_frame[index] for index in range(min(position_indices_used), max(position_indices_used) + 1)])

                # Store corresponding positions
                position_list.append(positions_used)

        # Apply kernel scaling or compression if necessary
        if 'scale' in corrections:

            # We need to use phase-ramp based kernel generation if we modify the positions
            use_phase_ramp = True

            # Modify position list
            for index in range(len(position_list)):
                _positions = np.asarray(position_list[index])
                for scale_correction in corrections['scale']:
                    factor, axis = scale_correction['factor'], scale_correction['axis']
                    _positions[:, axis] = ((_positions[:, axis] - yp.min(_positions[:, axis])) * factor + yp.min(_positions[:, axis]))
                position_list[index] = _positions.tolist()

        # Synthesize blur vectors
        blur_vector_list = []
        for frame_index in range(dataset.shape[0]):
            #  Generate blur vectors
            if use_phase_ramp:
                from llops.operators import PhaseRamp
                kernel_shape = [yp.fft.next_fast_len(max(sh, 1)) for sh in blur_vector_roi_list[frame_index].shape]
                offset = yp.cast([sh // 2 + st for (sh, st) in zip(kernel_shape, blur_vector_roi_list[frame_index].start)], 'complex32', dataset.backend)

                # Create phase ramp and calculate offset
                R = PhaseRamp(kernel_shape, dtype='complex32', backend=dataset.backend)

                # Generate blur vector
                blur_vector = yp.zeros(R.M, dtype='complex32', backend=dataset.backend)
                for pos, illum in zip(position_list[frame_index], illumination_list[frame_index]):
                    pos = yp.cast(pos, dtype=dataset.dtype, backend=dataset.backend)
                    blur_vector += (R * (yp.cast(pos - offset, 'complex32')))

                # Take inverse Fourier transform and keep the magnitude
                blur_vector = yp.abs(yp.iFt(blur_vector))

                if position_list[frame_index][0][-1] > position_list[frame_index][0][0]:
                    blur_vector = yp.flip(blur_vector)

            else:
                blur_vector = yp.asarray([illum[0]['value']['w'] for illum in illumination_list[frame_index]],
                                         dtype=dtype, backend=backend)

            # Normalize illumination vectors
            blur_vector /= yp.scalar(yp.sum(blur_vector))

            # Append to list
            blur_vector_list.append(blur_vector)

        # Return
        return blur_vector_list, blur_vector_roi_list
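
The phase-ramp branch above builds each kernel as a sum of Fourier-domain linear phases, one per position; the inverse FFT of that sum is a train of (possibly sub-pixel) deltas. A pure-NumPy sketch of the idea, with hypothetical positions and shape and no llops dependency:

    import numpy as np

    shape = (64, 64)
    positions = [(10.0, 12.5), (11.0, 13.5), (12.0, 14.5)]  # hypothetical (y, x)

    ky = np.fft.fftfreq(shape[0])[:, None]
    kx = np.fft.fftfreq(shape[1])[None, :]

    # Each position contributes a linear phase exp(-2j*pi*(ky*y + kx*x))
    kernel_f = np.zeros(shape, dtype=complex)
    for (y, x) in positions:
        kernel_f += np.exp(-2j * np.pi * (ky * y + kx * x))

    # The inverse FFT is a train of deltas at those positions; normalize to
    # unit sum, as blur_vectors does
    kernel = np.abs(np.fft.ifft2(kernel_f))
    kernel /= kernel.sum()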
Example #4
# Imports needed by this excerpt (the llops import paths are assumed from the
# surrounding codebase; _preprocessForRegistration is defined elsewhere in
# the same module)
import numpy as np
import scipy as sp
import scipy.fftpack  # makes sp.fftpack available
import llops as yp
import llops.operators as ops
from skimage.feature import ORB, match_descriptors, plot_matches
from skimage.measure import ransac
from skimage.transform import EuclideanTransform


def registerImage(image0,
                  image1,
                  method='xc',
                  axis=None,
                  preprocess_methods=('reflect',),
                  debug=False,
                  **kwargs):

    # Perform preprocessing
    if len(preprocess_methods) > 0:
        image0, image1 = _preprocessForRegistration(image0, image1,
                                                    preprocess_methods,
                                                    **kwargs)

    # Parameter on whether we can trust our registration
    trust_ratio = 1.0

    if method in ('xc', 'cross_correlation'):

        # Get energy ratio threshold
        trust_threshold = kwargs.get('energy_ratio_threshold', 1.5)

        # Pad arrays for optimal speed
        pad_size = tuple(
            [sp.fftpack.next_fast_len(s) for s in yp.shape(image0)])

        # Perform padding
        if pad_size != yp.shape(image0):
            image0 = yp.pad(image0, pad_size, pad_value='edge', center=True)
            image1 = yp.pad(image1, pad_size, pad_value='edge', center=True)

        # Take F.T. of measurements
        src_freq, target_freq = yp.Ft(image0, axes=axis), yp.Ft(image1,
                                                                axes=axis)

        # Whole-pixel shift - Compute cross-correlation by an IFFT
        image_product = src_freq * yp.conj(target_freq)
        # image_product /= abs(src_freq * yp.conj(target_freq))
        cross_correlation = yp.iFt(image_product, center=False, axes=axis)

        # Take sum along axis if we're doing 1D
        if axis is not None:
            axis_to_sum = list(range(yp.ndim(image1)))
            del axis_to_sum[axis]
            cross_correlation = yp.sum(cross_correlation, axis=axis_to_sum)

        # Locate maximum
        shape = yp.shape(src_freq)
        maxima = yp.argmax(yp.abs(cross_correlation))
        midpoints = np.array([np.fix(axis_size / 2) for axis_size in shape])

        shifts = np.array(maxima, dtype=np.float64)
        shifts[shifts > midpoints] -= np.array(shape)[shifts > midpoints]

        # If it's only one row or column, the shift along that dimension has
        # no effect; set it to zero.
        for dim in range(yp.ndim(src_freq)):
            if shape[dim] == 1:
                shifts[dim] = 0

        # If energy ratio is too small, set all shifts to zero
        trust_metric = yp.scalar(
            yp.max(yp.abs(cross_correlation)**2) /
            yp.mean(yp.abs(cross_correlation)**2))

        # Determine if this registration can be trusted
        trust_ratio = trust_metric / trust_threshold

    elif method == 'orb':

        # Get user-defined mean_residual_threshold if given
        trust_threshold = kwargs.get('mean_residual_threshold', 40.0)

        # Get user-defined orb_feature_threshold if given
        orb_feature_threshold = kwargs.get('orb_feature_threshold', 25)

        match_count = 0
        fast_threshold = 0.05
        while match_count < orb_feature_threshold:
            descriptor_extractor = ORB(n_keypoints=500,
                                       fast_n=9,
                                       harris_k=0.1,
                                       fast_threshold=fast_threshold)

            # Extract keypoints from first frame
            descriptor_extractor.detect_and_extract(
                np.asarray(image0).astype(np.double))
            keypoints0 = descriptor_extractor.keypoints
            descriptors0 = descriptor_extractor.descriptors

            # Extract keypoints from second frame
            descriptor_extractor.detect_and_extract(
                np.asarray(image1).astype(np.double))
            keypoints1 = descriptor_extractor.keypoints
            descriptors1 = descriptor_extractor.descriptors

            # Set match count
            match_count = min(len(keypoints0), len(keypoints1))
            fast_threshold -= 0.01

            if fast_threshold <= 0:
                raise RuntimeError(
                    'Could not find any keypoints (even after shrinking fast threshold).'
                )

        # Match descriptors
        matches = match_descriptors(descriptors0,
                                    descriptors1,
                                    cross_check=True)

        # Filter descriptors to axes (if provided)
        if axis is not None:
            matches_filtered = []
            for (index_0, index_1) in matches:
                point_0 = keypoints0[index_0, :]
                point_1 = keypoints1[index_1, :]
                unit_vec = point_0 - point_1
                unit_vec /= np.linalg.norm(unit_vec)

                if yp.abs(unit_vec[axis]) > 0.99:
                    matches_filtered.append((index_0, index_1))

            matches_filtered = np.asarray(matches_filtered)
        else:
            matches_filtered = matches

        # Robustly estimate affine transform model with RANSAC
        model_robust, inliers = ransac((keypoints0[matches_filtered[:, 0]],
                                        keypoints1[matches_filtered[:, 1]]),
                                       EuclideanTransform,
                                       min_samples=3,
                                       residual_threshold=2,
                                       max_trials=100)

        # Note that model_robust has a translation property, but this doesn't
        # seem to be as numerically stable as simply averaging the difference
        # between the coordinates along the desired axis.

        # Apply match filter
        matches_filtered = matches_filtered[inliers, :]

        # Process keypoints
        if yp.shape(matches_filtered)[0] > 0:

            # Compute shifts
            difference = keypoints0[matches_filtered[:, 0]] - keypoints1[
                matches_filtered[:, 1]]
            shifts = (yp.sum(difference, axis=0) / yp.shape(difference)[0])
            shifts = np.round(shifts[0])

            # Filter to axis mask
            if axis is not None:
                _shifts = [0, 0]
                _shifts[axis] = shifts[axis]
                shifts = _shifts

            # Calculate residuals
            residuals = yp.sqrt(
                yp.sum(
                    yp.abs(keypoints0[matches_filtered[:, 0]] +
                           np.asarray(shifts) -
                           keypoints1[matches_filtered[:, 1]])**2))

            # Define a trust metric
            trust_metric = residuals / yp.shape(
                keypoints0[matches_filtered[:, 0]])[0]

            # Determine if this registration can be trusted
            trust_ratio = 1 / (trust_metric / trust_threshold)
            if debug:
                print('===')
                print(trust_ratio)
                print(trust_threshold)
                print(trust_metric)
                print(shifts)
        else:
            trust_metric = 1e10
            trust_ratio = 0.0
            shifts = np.asarray([0, 0])

    elif method == 'optimize':

        # Create Operators
        L2 = ops.L2Norm(yp.shape(image0), dtype='complex64')
        R = ops.PhaseRamp(yp.shape(image0), dtype='complex64')
        REAL = ops.RealFilter((2, 1), dtype='complex64')

        # Take Fourier Transforms of images
        image0_f, image1_f = yp.astype(yp.Ft(image0), 'complex64'), yp.astype(
            yp.Ft(image1), 'complex64')

        # Diagonalize one of the images
        D = ops.Diagonalize(image0_f)

        # Form objective
        objective = L2 * (D * R * REAL - image1_f)

        # Solve objective
        solver = ops.solvers.GradientDescent(objective)
        shifts = solver.solve(iteration_count=1000, step_size=1e-8)

        # Convert to numpy array, take real part, and round.
        shifts = yp.round(yp.real(yp.asbackend(shifts, 'numpy')))

        # Flip shift axes from (x, y) to (y, x)
        shifts = np.fliplr(shifts)

        # TODO: Trust metric and trust_threshold
        trust_threshold = 1
        trust_ratio = 1.0

    else:
        raise ValueError('Invalid Registration Method %s' % method)

    # Mark whether or not this measurement is of good quality
    if trust_ratio <= 1:
        if debug:
            print('Ignoring shift with trust metric %g (threshold is %g)' %
                  (trust_metric, trust_threshold))
        shifts = yp.zeros_like(np.asarray(shifts)).tolist()

    # Show debugging figures if requested
    if debug:
        import matplotlib.pyplot as plt
        plt.figure(figsize=(6, 5))
        plt.subplot(131)
        plt.imshow(yp.abs(image0))
        plt.axis('off')
        plt.subplot(132)
        plt.imshow(yp.abs(image1))
        plt.title('Trust ratio: %g' % (trust_ratio))
        plt.axis('off')
        plt.subplot(133)
        if method in ('xc', 'cross_correlation'):
            if axis is not None:
                plt.plot(yp.abs(yp.squeeze(cross_correlation)))
            else:
                plt.imshow(yp.abs(yp.fftshift(cross_correlation)))
        else:
            plot_matches(plt.gca(), yp.real(image0), yp.real(image1),
                         keypoints0, keypoints1, matches_filtered)
        plt.title(str(shifts))
        plt.axis('off')

    # Return
    return shifts, trust_ratio
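
The trust metric in the 'xc' branch is the peak-to-mean energy ratio of the cross-correlation. A self-contained NumPy illustration of why it separates real matches from noise (mean-subtracted here to suppress the DC term, which the function itself does not do):

    import numpy as np

    def energy_ratio(a, b):
        a, b = a - a.mean(), b - b.mean()
        xc = np.fft.ifft2(np.fft.fft2(a) * np.conj(np.fft.fft2(b)))
        power = np.abs(xc) ** 2
        return power.max() / power.mean()

    rng = np.random.default_rng(0)
    a = rng.random((64, 64))
    print(energy_ratio(a, np.roll(a, 5, axis=0)))  # large: trustworthy peak
    print(energy_ratio(a, rng.random((64, 64))))   # small: no real peak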
Example #5
def register_translation(src_image,
                         target_image,
                         upsample_factor=1,
                         energy_ratio_threshold=2,
                         space="real"):
    """
    Efficient subpixel image translation registration by cross-correlation.
    This code gives the same precision as the FFT upsampled cross-correlation
    in a fraction of the computation time and with reduced memory requirements.
    It obtains an initial estimate of the cross-correlation peak by an FFT and
    then refines the shift estimation by upsampling the DFT only in a small
    neighborhood of that estimate by means of a matrix-multiply DFT.
    Parameters
    ----------
    src_image : ndarray
        Reference image.
    target_image : ndarray
        Image to register.  Must be same dimensionality as ``src_image``.
    upsample_factor : int, optional
        Upsampling factor. Images will be registered to within
        ``1 / upsample_factor`` of a pixel. For example
        ``upsample_factor == 20`` means the images will be registered
        within 1/20th of a pixel.  Default is 1 (no upsampling). Note that
        the subpixel refinement step is currently commented out below.
    energy_ratio_threshold : float, optional
        Minimum peak-to-mean energy ratio of the cross-correlation required
        to trust the estimated shift; below this, the shift is set to zero.
    space : string, one of "real" or "fourier", optional
        Defines how the algorithm interprets input data.  "real" means data
        will be FFT'd to compute the correlation, while "fourier" data will
        bypass FFT of input data.  Case insensitive.
    Returns
    -------
    shifts : ndarray
        Shift vector (in pixels) required to register ``target_image`` with
        ``src_image``.  Axis ordering is consistent with numpy (e.g. Z, Y, X)
    References
    ----------
    .. [1] Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup,
           "Efficient subpixel image registration algorithms,"
           Optics Letters 33, 156-158 (2008). :DOI:`10.1364/OL.33.000156`
    .. [2] James R. Fienup, "Invariant error metrics for image reconstruction,"
           Applied Optics 36, 8352-8357 (1997). :DOI:`10.1364/AO.36.008352`
    """
    # images must be the same shape
    if yp.shape(src_image) != yp.shape(target_image):
        raise ValueError("Error: images must be same size for "
                         "register_translation")

    # subpixel registration is only implemented for 2D and 3D data
    if yp.ndim(src_image) > 3 and upsample_factor > 1:
        raise NotImplementedError("Error: register_translation only supports "
                                  "subpixel registration for 2D and 3D images")

    # assume complex data is already in Fourier space
    if space.lower() == 'fourier':
        src_freq = src_image
        target_freq = target_image
    # real data needs to be fft'd.
    elif space.lower() == 'real':
        src_freq = yp.Ft(src_image)
        target_freq = yp.Ft(target_image)
    else:
        raise ValueError("Error: register_translation only knows the \"real\" "
                         "and \"fourier\" values for the ``space`` argument.")

    # Whole-pixel shift - Compute cross-correlation by an IFFT
    shape = yp.shape(src_freq)
    image_product = src_freq * yp.conj(target_freq)
    cross_correlation = yp.iFt(image_product, center=False)

    # Locate maximum
    maxima = yp.argmax(yp.abs(cross_correlation))
    midpoints = np.array([np.fix(axis_size / 2) for axis_size in shape])

    shifts = np.array(maxima, dtype=np.float64)
    shifts[shifts > midpoints] -= np.array(shape)[shifts > midpoints]

    # if upsample_factor > 1:
    #     # Initial shift estimate in upsampled grid
    #     shifts = np.round(shifts * upsample_factor) / upsample_factor
    #     upsampled_region_size = np.ceil(upsample_factor * 1.5)
    #     # Center of output array at dftshift + 1
    #     dftshift = np.fix(upsampled_region_size / 2.0)
    #     upsample_factor = np.array(upsample_factor, dtype=np.float64)
    #     normalization = (src_freq.size * upsample_factor ** 2)
    #     # Matrix multiply DFT around the current shift estimate
    #     sample_region_offset = dftshift - shifts*upsample_factor
    #     cross_correlation = _upsampled_dft(image_product.conj(),
    #                                        upsampled_region_size,
    #                                        upsample_factor,
    #                                        sample_region_offset).conj()
    #     cross_correlation /= normalization
    #     # Locate maximum and map back to original pixel grid
    #     maxima = np.array(np.unravel_index(
    #                           np.argmax(np.abs(cross_correlation)),
    #                           cross_correlation.shape),
    #                       dtype=np.float64)
    #     maxima -= dftshift
    #
    #     shifts = shifts + maxima / upsample_factor

    # If it's only one row or column, the shift along that dimension has no
    # effect; set it to zero.
    for dim in range(yp.ndim(src_freq)):
        if shape[dim] == 1:
            shifts[dim] = 0

    # If the energy ratio (peak-to-mean power of the cross-correlation) is
    # too small, the peak is not trustworthy; set all shifts to zero
    energy_ratio = yp.max(yp.abs(cross_correlation)**2) / yp.sum(
        yp.abs(cross_correlation)**2) * yp.prod(yp.shape(cross_correlation))
    if energy_ratio < energy_ratio_threshold:
        print('Ignoring shift with energy ratio %g (threshold is %g)' %
              (energy_ratio, energy_ratio_threshold))
        shifts = yp.zeros_like(shifts)

    return shifts
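
A self-contained NumPy sketch of the whole-pixel path above (circular shift assumed, no llops), including the midpoint wrap-around that maps large peak indices to negative shifts:

    import numpy as np

    rng = np.random.default_rng(0)
    src = rng.random((64, 64))
    target = np.roll(src, shift=(-6, 9), axis=(0, 1))

    # Cross-correlation via the cross-power spectrum
    src_freq, target_freq = np.fft.fft2(src), np.fft.fft2(target)
    cross_correlation = np.fft.ifft2(src_freq * np.conj(target_freq))

    # Locate the peak and wrap indices past the midpoint to negative shifts
    shape = np.array(cross_correlation.shape)
    maxima = np.unravel_index(np.argmax(np.abs(cross_correlation)),
                              cross_correlation.shape)
    shifts = np.array(maxima, dtype=np.float64)
    midpoints = np.fix(shape / 2)
    shifts[shifts > midpoints] -= shape[shifts > midpoints]

    # Rolling target by these shifts registers it to src
    assert np.allclose(np.roll(target, shifts.astype(int), axis=(0, 1)), src)
    print(shifts)  # [ 6. -9.]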