Example #1
0
        def _centering(image, popt):
            """
            Shift ``image`` by the negated best-fit offsets.

            With ``self.m_method == "full"`` the offsets are refitted for this particular image
            with ``_least_squares``; otherwise the ``popt`` passed in is used unchanged.
            """
            fit_params = _least_squares(np.copy(image)) if self.m_method == "full" else popt

            # shift_image takes the shift in (y, x) order.
            # NOTE(review): assumes fit_params[0] is the x offset and fit_params[1] the y offset.
            shift_yx = (-fit_params[1], -fit_params[0])

            return shift_image(image, shift_yx, self.m_interpolation)
Example #2
0
    def run(self) -> None:
        """
        Run method of the module. Shifts an image with a fifth order spline, bilinear, or a
        Fourier shift interpolation.

        The shift is either ``self.m_shift`` or, when ``self.m_fit_in_port`` is set, the negated
        offsets from columns 0 and 2 of the fit results. Shifts are passed to ``shift_image`` in
        (y, x) order. The history entry reports them in (x, y) order, matching its
        ``shift_xy`` label.

        Returns
        -------
        NoneType
            None
        """

        constant = True

        # Read the fit results from self.m_fit_in_port if available
        if self.m_fit_in_port is not None:

            # Columns 0 and 2 of the fit results are used as the (x, y) offsets. They are
            # negated and reordered to (y, x), which is the order expected by shift_image.
            self.m_shift = -1. * self.m_fit_in_port[:, [0, 2]]  # (x, y)
            self.m_shift = self.m_shift[:, [1, 0]]  # (y, x)

            # Only the offsets determine the shifted images, so only they need to be constant
            # for a single shift to be applied to all images.
            if not np.allclose(self.m_shift - self.m_shift[0, ], 0.0):
                constant = False

            if constant:
                # The offset is constant, so use the first row for all images
                self.m_shift = self.m_shift[0, ]

            else:
                # The offsets differ, so apply the shifts to each image individually
                for i, shift in enumerate(self.m_shift):
                    shifted_image = shift_image(self.m_image_in_port[i, ],
                                                shift, self.m_interpolation)

                    # Append the shifted image to the self.m_image_out_port database entry
                    self.m_image_out_port.append(shifted_image, data_dim=3)

                # mean_shift is ordered (y, x); report it as (x, y)
                mean_shift = np.mean(self.m_shift, axis=0)
                history = f'shift_xy = {mean_shift[1]:.2f}, {mean_shift[0]:.2f}'

        # Apply a constant shift
        if constant:

            self.apply_function_to_images(shift_image,
                                          self.m_image_in_port,
                                          self.m_image_out_port,
                                          'Shifting the images',
                                          func_args=(self.m_shift,
                                                     self.m_interpolation))

            # self.m_shift is passed directly to shift_image and is therefore ordered (y, x);
            # report it as (x, y)
            history = f'shift_xy = {self.m_shift[1]:.2f}, {self.m_shift[0]:.2f}'

        self.m_image_out_port.copy_attributes(self.m_image_in_port)
        self.m_image_out_port.add_history('ShiftImagesModule', history)
        self.m_image_out_port.close_port()
Example #3
0
    def run(self) -> None:
        """
        Run method of the module. Normalizes the images for the different filter widths,
        upscales the images, and crops the images to the initial image shape in order to
        align the PSF patterns.

        Each image is multiplied by the ratio of the line and continuum filter widths and
        rescaled by the ratio of the line and continuum wavelengths. The rescaled image is then
        centered in a frame with the original image size. It is cropped when it is larger than
        the frame and zero-padded when it is smaller. An odd size difference is compensated
        with a half-pixel spline shift.

        Returns
        -------
        NoneType
            None
        """

        self.m_image_out_port.del_all_data()
        self.m_image_out_port.del_all_attributes()

        wvl_factor = self.m_line_wvl / self.m_cnt_wvl
        width_factor = self.m_line_width / self.m_cnt_width

        nimages = self.m_image_in_port.get_shape()[0]

        start_time = time.time()
        for i in range(nimages):
            progress(i, nimages, 'Running SDIpreparationModule...', start_time)

            image = self.m_image_in_port[i, ]

            im_scale = width_factor * scale_image(image, wvl_factor,
                                                  wvl_factor)

            # Size difference along the last axis; the same margins are applied to both axes.
            # An odd difference puts the extra pixel at the end (npix_del_b).
            npix_del = im_scale.shape[-1] - image.shape[-1]
            npix_del_a = abs(npix_del) // 2
            npix_del_b = abs(npix_del) - npix_del_a

            if npix_del >= 0:
                # Crop the rescaled image. Explicit end indices are used because a slice such
                # as [0:-0] would be empty when no pixels have to be removed.
                im_crop = im_scale[npix_del_a:im_scale.shape[-2] - npix_del_b,
                                   npix_del_a:im_scale.shape[-1] - npix_del_b]

            else:
                # The rescaled image is smaller than the original, so zero-pad it
                im_crop = np.zeros(image.shape)
                im_crop[npix_del_a:image.shape[-2] - npix_del_b,
                        npix_del_a:image.shape[-1] - npix_del_b] = im_scale

            if npix_del % 2 == 1:
                # Compensate the asymmetric margins with a half-pixel shift
                half_shift = (-0.5, -0.5) if npix_del > 0 else (0.5, 0.5)
                im_crop = shift_image(im_crop, half_shift,
                                      interpolation='spline')

            self.m_image_out_port.append(im_crop, data_dim=3)

        sys.stdout.write('Running SDIpreparationModule... [DONE]\n')
        sys.stdout.flush()

        history = f'(line, continuum) = ({self.m_line_wvl}, {self.m_cnt_wvl})'
        self.m_image_out_port.copy_attributes(self.m_image_in_port)
        self.m_image_out_port.add_history('SDIpreparationModule', history)
        self.m_image_out_port.close_port()
Example #4
0
def align_image(image_in: np.ndarray,
                im_index: int,
                interpolation: str,
                accuracy: float,
                resize: Optional[float],
                num_references: int,
                subframe: Optional[float],
                ref_images_reshape: np.ndarray,
                ref_images_shape: Tuple[int, int, int]) -> np.ndarray:
    """
    Align an image with a set of reference images and shift it accordingly.

    The offset between ``image_in`` and each reference image is measured with
    ``phase_cross_correlation``, and the offsets are averaged. The image is optionally rescaled
    (with the offset scaled by the same factor) and then shifted by the averaged offset.

    Parameters
    ----------
    image_in : np.ndarray
        Image that is aligned.
    im_index : int
        Index of the image. Not used by this function. NOTE(review): presumably required by
        the calling convention of ``apply_function_to_images``.
    interpolation : str
        Interpolation type that is passed to ``shift_image``.
    accuracy : float
        Upsampling factor that is passed to ``phase_cross_correlation``.
    resize : float, None
        Scaling factor for the image. The image is not rescaled if set to None.
    num_references : int
        Number of reference images for which the offset is measured and averaged.
    subframe : float, None
        Size of the subframe (passed to ``crop_image``) that is used for the cross-correlation.
        The full image is used if set to None.
    ref_images_reshape : np.ndarray
        Reference images, which are reshaped to ``ref_images_shape``.
    ref_images_shape : tuple(int, int, int)
        Original 3D shape of the reference images.

    Returns
    -------
    np.ndarray
        Shifted (and optionally rescaled) image.
    """

    offset = np.array([0., 0.])

    # Reshape the reference images back to their original 3D shape.
    # The original shape can not be used directly because of util.module.update_arguments
    ref_images = ref_images_reshape.reshape(ref_images_shape)

    if subframe is not None:
        # The cropped science image is the same for every reference image, so crop it once
        sub_in = crop_image(image_in, None, subframe)

    for i in range(num_references):
        if subframe is None:
            tmp_offset, _, _ = phase_cross_correlation(ref_images[i, :, :],
                                                       image_in,
                                                       upsample_factor=accuracy)

        else:
            sub_ref = crop_image(ref_images[i, :, :], None, subframe)

            tmp_offset, _, _ = phase_cross_correlation(sub_ref,
                                                       sub_in,
                                                       upsample_factor=accuracy)
        offset += tmp_offset

    offset /= float(num_references)

    if resize is not None:
        offset *= resize

        sum_before = np.sum(image_in)

        tmp_image = rescale(image_in,
                            (resize, resize),
                            order=5,
                            mode='reflect',
                            multichannel=False,
                            anti_aliasing=True)

        sum_after = np.sum(tmp_image)

        # Restore the total flux. Rescaling changes the number of pixels, and rescale may also
        # change the value range (for example, integer input is converted to floats). The
        # correction is skipped when the rescaled sum is zero (for example, an all-zero frame),
        # where the ratio is undefined and would fill the image with NaN.
        if sum_after != 0.:
            tmp_image = tmp_image*(sum_before/sum_after)

    else:
        tmp_image = image_in

    return shift_image(tmp_image, offset, interpolation)
Example #5
0
def sdi_scaling(image_in: np.ndarray, scaling: np.ndarray) -> np.ndarray:
    """
    Function to rescale the images by their wavelength ratios.

    Each rescaled image is centered in a frame with the original image size. It is cropped
    when it is larger than the frame (scaling factor > 1) and zero-padded when it is smaller
    (scaling factor < 1). An odd size difference is compensated with a half-pixel spline shift.
    The size difference is measured along the last axis and applied to both axes.

    Parameters
    ----------
    image_in : np.ndarray
        Data to rescale
    scaling : np.ndarray
        Scaling factors.

    Returns
    -------
    np.ndarray
        Rescaled images with the same shape as ``image_in``.

    Raises
    ------
    ValueError
        If the number of images is not equal to the number of scaling factors.
    """

    if image_in.shape[0] != scaling.shape[0]:
        raise ValueError(
            'The number of wavelengths is not equal to the number of available '
            'scaling factors.')

    image_out = np.zeros(image_in.shape)

    for i in range(image_in.shape[0]):
        swaps = scale_image(image_in[i, ], scaling[i], scaling[i])

        # Positive: pixels to remove; negative: pixels to add. An odd difference puts the
        # extra pixel at the end (npix_del_b).
        npix_del = swaps.shape[-1] - image_out.shape[-1]
        npix_del_a = abs(npix_del) // 2
        npix_del_b = abs(npix_del) - npix_del_a

        if npix_del >= 0:
            # Crop. Explicit end indices keep the full image when npix_del == 0.
            image_out[i, ] = swaps[npix_del_a:swaps.shape[-2] - npix_del_b,
                                   npix_del_a:swaps.shape[-1] - npix_del_b]

        else:
            # Zero-pad; image_out is already initialized with zeros
            image_out[i, npix_del_a:image_out.shape[-2] - npix_del_b,
                      npix_del_a:image_out.shape[-1] - npix_del_b] = swaps

        if npix_del % 2 == 1:
            # Half-pixel correction for the asymmetric margins: -0.5 pixel after cropping,
            # +0.5 pixel after padding
            half_shift = (-0.5, -0.5) if npix_del > 0 else (0.5, 0.5)
            image_out[i, ] = shift_image(image_out[i, ], half_shift,
                                         interpolation='spline')

    return image_out
Example #6
0
        def _align_image(image_in):
            """
            Align ``image_in`` with the reference images and shift it accordingly.

            The offsets to the first ``self.m_num_references`` reference images are measured
            with ``register_translation`` (optionally on a subframe) and averaged. The image is
            optionally rescaled by ``self.m_resize`` before it is shifted.
            """
            offset = np.array([0., 0.])

            if self.m_subframe is not None:
                # The cropped science image is the same for every reference image
                sub_in = crop_image(image_in, None, self.m_subframe)

            for i in range(self.m_num_references):
                if self.m_subframe is None:
                    tmp_offset, _, _ = register_translation(
                        ref_images[i, :, :],
                        image_in,
                        upsample_factor=self.m_accuracy)

                else:
                    sub_ref = crop_image(ref_images[i, :, :], None,
                                         self.m_subframe)

                    tmp_offset, _, _ = register_translation(
                        sub_ref, sub_in, upsample_factor=self.m_accuracy)
                offset += tmp_offset

            offset /= float(self.m_num_references)

            if self.m_resize is not None:
                offset *= self.m_resize

                sum_before = np.sum(image_in)

                tmp_image = rescale(image=np.asarray(image_in,
                                                     dtype=np.float64),
                                    scale=(self.m_resize, self.m_resize),
                                    order=5,
                                    mode='reflect',
                                    anti_aliasing=True,
                                    multichannel=False)

                sum_after = np.sum(tmp_image)

                # Restore the total flux, which changes with the number of pixels. Skipped when
                # the rescaled sum is zero (for example, an all-zero frame), where the ratio is
                # undefined and would fill the image with NaN.
                if sum_after != 0.:
                    tmp_image = tmp_image * (sum_before / sum_after)

            else:
                tmp_image = image_in

            return shift_image(tmp_image, offset, self.m_interpolation)
Example #7
0
def apply_shift(image_in: np.ndarray,
                im_index: int,
                shift: Union[Tuple[float, float], np.ndarray],
                interpolation: str) -> np.ndarray:
    """
    Shift a single image with ``shift_image``.

    Parameters
    ----------
    image_in : np.ndarray
        Image that is shifted.
    im_index : int
        Index of the image. Not used by this function. NOTE(review): presumably kept for the
        calling convention of ``apply_function_to_images``.
    shift : tuple(float, float), np.ndarray
        Shift that is passed unchanged to ``shift_image``.
    interpolation : str
        Interpolation type that is passed unchanged to ``shift_image``.

    Returns
    -------
    np.ndarray
        Shifted image.
    """

    shifted = shift_image(image_in, shift, interpolation)
    return shifted
Example #8
0
    def run(self) -> None:
        """
        Run method of the module. Locates the position of the calibration spots in the center
        frame. From the four spots, the position of the star behind the coronagraph is fitted,
        and the images are shifted and cropped.

        The approximate spot positions lie on a circle with radius ``self.m_radius`` around
        ``self.m_center``: at 45, 135, 225 and 315 deg for the 'x' pattern, and at 0, 90, 180 and
        270 deg for the '+' pattern. Each spot position is refined with a 2D Gaussian fit. The
        star position is the intersection of the line through spots 0 and 2 with the line
        through spots 1 and 3.

        Returns
        -------
        NoneType
            None

        Raises
        ------
        ValueError
            If the science and center images have a different shape, or if ``self.m_pattern``
            is neither 'x' nor '+'.
        """
        def _get_center(center):
            # Use the first center frame; a given center is rounded down to whole pixels
            center_frame = self.m_center_in_port[0, ]

            if center_shape[0] > 1:
                warnings.warn(
                    'Multiple center images found. Using the first image of the stack.'
                )

            if center is None:
                center = center_pixel(center_frame)
            else:
                center = (np.floor(center[0]), np.floor(center[1]))

            return center_frame, center

        def _intersection(x_spots, y_spots):
            # Intersection of the line through spots 0 and 2 with the line through spots 1
            # and 3, in the determinant form of the line-line intersection. Unlike the
            # slope-intercept form this stays finite when a line is vertical (e.g. spots 1
            # and 3 of the '+' pattern).
            x_a, x_b, x_c, x_d = x_spots
            y_a, y_b, y_c, y_d = y_spots

            denom = (x_a - x_c) * (y_b - y_d) - (y_a - y_c) * (x_b - x_d)
            det_ac = x_a * y_c - y_a * x_c
            det_bd = x_b * y_d - y_b * x_d

            x_int = (det_ac * (x_b - x_d) - (x_a - x_c) * det_bd) / denom
            y_int = (det_ac * (y_b - y_d) - (y_a - y_c) * det_bd) / denom

            return x_int, y_int

        self.m_image_out_port.del_all_data()
        self.m_image_out_port.del_all_attributes()

        center_shape = self.m_center_in_port.get_shape()
        im_shape = self.m_image_in_port.get_shape()

        center_frame, self.m_center = _get_center(self.m_center)

        if im_shape[-2:] != center_shape[-2:]:
            raise ValueError(
                'Science and center images should have the same shape.')

        # Angular offset of the spots in units of pi/4. An unknown pattern is rejected here;
        # otherwise it would fail later with an undefined spot position.
        if self.m_pattern == 'x':
            pattern_offset = 1
        elif self.m_pattern == '+':
            pattern_offset = 0
        else:
            raise ValueError(
                f'The pattern {self.m_pattern} is not valid. Please select '
                f'either \'x\' or \'+\'.')

        pixscale = self.m_image_in_port.get_attribute('PIXSCALE')

        self.m_sigma /= pixscale

        if self.m_size is not None:
            self.m_size = int(math.ceil(self.m_size / pixscale))

        if self.m_dither:
            dither_x = self.m_image_in_port.get_attribute('DITHER_X')
            dither_y = self.m_image_in_port.get_attribute('DITHER_Y')

            nframes = self.m_image_in_port.get_attribute('NFRAMES')
            nframes = np.cumsum(nframes)
            nframes = np.insert(nframes, 0, 0)

        # Subtract a Gaussian-smoothed copy of the center frame
        center_frame_unsharp = center_frame - gaussian_filter(
            input=center_frame, sigma=self.m_sigma)

        # size of center image, only works with odd value
        ref_image_size = 21

        # Arrays for the positions
        x_pos = np.zeros(4)
        y_pos = np.zeros(4)

        # Loop for 4 waffle spots
        for i in range(4):
            # Approximate positions of waffle spots
            spot_angle = np.pi / 4. * (2 * i + pattern_offset)

            x_0 = np.floor(self.m_center[0] + self.m_radius * np.cos(spot_angle))
            y_0 = np.floor(self.m_center[1] + self.m_radius * np.sin(spot_angle))

            tmp_center_frame = crop_image(image=center_frame_unsharp,
                                          center=(int(y_0), int(x_0)),
                                          size=ref_image_size)

            # find maximum in tmp image
            coords = np.unravel_index(indices=np.argmax(tmp_center_frame),
                                      shape=tmp_center_frame.shape)

            y_max, x_max = coords[0], coords[1]

            pixmax = tmp_center_frame[y_max, x_max]
            max_pos = np.array([x_max, y_max]).reshape(1, 2)

            # Check whether it is the correct maximum: second brightest pixel should be nearby
            tmp_center_frame[y_max, x_max] = 0.

            # introduce distance parameter
            dist = np.inf

            while dist > 2:
                coords = np.unravel_index(indices=np.argmax(tmp_center_frame),
                                          shape=tmp_center_frame.shape)

                y_max_new, x_max_new = coords[0], coords[1]

                pixmax_new = tmp_center_frame[y_max_new, x_max_new]

                # Calculate minimal distance to previous points
                tmp_center_frame[y_max_new, x_max_new] = 0.

                dist = np.amin(
                    np.linalg.norm(np.vstack((max_pos[:, 0] - x_max_new,
                                              max_pos[:, 1] - y_max_new)),
                                   axis=0))

                if dist <= 2 and pixmax_new < pixmax:
                    break

                max_pos = np.vstack((max_pos, [x_max_new, y_max_new]))

                x_max = x_max_new
                y_max = y_max_new
                pixmax = pixmax_new

            x_0 = x_0 - (ref_image_size - 1) / 2 + x_max
            y_0 = y_0 - (ref_image_size - 1) / 2 + y_max

            # create reference image around determined maximum
            ref_center_frame = crop_image(image=center_frame_unsharp,
                                          center=(int(y_0), int(x_0)),
                                          size=ref_image_size)

            # Fit the data using astropy.modeling
            gauss_init = models.Gaussian2D(amplitude=np.amax(ref_center_frame),
                                           x_mean=x_0,
                                           y_mean=y_0,
                                           x_stddev=1.,
                                           y_stddev=1.,
                                           theta=0.)

            fit_gauss = fitting.LevMarLSQFitter()

            y_grid, x_grid = np.mgrid[y_0 - (ref_image_size - 1) / 2:y_0 +
                                      (ref_image_size - 1) / 2 + 1,
                                      x_0 - (ref_image_size - 1) / 2:x_0 +
                                      (ref_image_size - 1) / 2 + 1]

            gauss = fit_gauss(gauss_init, x_grid, y_grid, ref_center_frame)

            x_pos[i] = gauss.x_mean.value
            y_pos[i] = gauss.y_mean.value

        # Find star position as intersection of two lines
        x_center, y_center = _intersection(x_pos, y_pos)

        nimages = self.m_image_in_port.get_shape()[0]
        npix = self.m_image_in_port.get_shape()[1]

        # Shift (y, x) that moves the star to the image center; the same for all images apart
        # from the dither offsets that are added per image below
        shift_center = np.array([(float(im_shape[-2]) - 1.) / 2. - y_center,
                                 (float(im_shape[-1]) - 1.) / 2. - x_center])

        start_time = time.time()
        for i in range(nimages):
            progress(i, nimages, 'Centering the images...', start_time)

            image = self.m_image_in_port[i, ]

            shift_yx = shift_center.copy()

            if self.m_dither:
                index = np.digitize(i, nframes, right=False) - 1

                shift_yx[0] -= dither_y[index]
                shift_yx[1] -= dither_x[index]

            if npix % 2 == 0 and self.m_size is not None:
                # Add a row and column of zeros so the image has an odd size before cropping
                im_tmp = np.zeros((image.shape[0] + 1, image.shape[1] + 1))
                im_tmp[:-1, :-1] = image
                image = im_tmp

                shift_yx[0] += 0.5
                shift_yx[1] += 0.5

            im_shift = shift_image(image, shift_yx, 'spline')

            if self.m_size is not None:
                im_crop = crop_image(im_shift, None, self.m_size)
                self.m_image_out_port.append(im_crop, data_dim=3)
            else:
                self.m_image_out_port.append(im_shift, data_dim=3)

        print(f'Center [x, y] = [{x_center}, {y_center}]')

        history = f'[x, y] = [{round(x_center, 2)}, {round(y_center, 2)}]'
        self.m_image_out_port.copy_attributes(self.m_image_in_port)
        self.m_image_out_port.add_history('WaffleCenteringModule', history)
        self.m_image_out_port.close_port()
Example #9
0
def pca_psf_subtraction(
        images: np.ndarray,
        angles: Optional[np.ndarray],
        pca_number: Union[int, np.int64],
        scales: Optional[np.ndarray] = None,
        pca_sklearn: Optional[PCA] = None,
        im_shape: Optional[tuple] = None,
        indices: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Function for PSF subtraction with PCA.

    Parameters
    ----------
    images : np.ndarray
        Stack of images. Also used as reference images if ``pca_sklearn`` is set to None. The
        data should have the original 3D shape if ``pca_sklearn`` is set to None or it should be
        in a 2D reshaped format if ``pca_sklearn`` is not set to None.
    angles : np.ndarray, None
        Parallactic angles (deg). The residuals are not derotated if set to None.
    pca_number : int
        Number of principal components.
    scales : np.ndarray, None
        Scaling factors for SDI. Each residual image is rescaled by the inverse of its factor
        and centered in a frame of the original size. Not used if set to None.
    pca_sklearn : sklearn.decomposition.pca.PCA, None
        PCA object with the principal components.
    im_shape : tuple(int, int, int), None
        The original 3D shape of the stack with images. Only required if ``pca_sklearn`` is not set
        to None.
    indices : np.ndarray, None
        Array with the indices of the pixels that are used for the PSF subtraction. All pixels are
        used if set to None.

    Returns
    -------
    np.ndarray
        Residuals of the PSF subtraction (rescaled back if ``scales`` is provided).
    np.ndarray
        Derotated residuals of the PSF subtraction.

    Raises
    ------
    ValueError
        If the number of images does not match the number of scaling factors or angles.
    """

    if pca_sklearn is None:
        # Create a PCA object if not provided as argument
        pca_sklearn = PCA(n_components=pca_number, svd_solver='arpack')

        # The 3D shape of the array with images
        im_shape = images.shape

        if indices is None:
            # Select the first image and get the unmasked image indices
            im_star = images[0, ].reshape(-1)
            indices = np.where(im_star != 0.)[0]

        # Reshape the images and select the unmasked pixels
        im_reshape = images.reshape(im_shape[0], im_shape[1] * im_shape[2])
        im_reshape = im_reshape[:, indices]

        # Subtract the mean image
        # This is also done by sklearn.decomposition.PCA.fit()
        im_reshape -= np.mean(im_reshape, axis=0)

        # Fit the principal components
        pca_sklearn.fit(im_reshape)

    else:
        # If the PCA object is already there then so are the reshaped data
        im_reshape = np.copy(images)

    # Project the data on the principal components
    # Note that this is the same as sklearn.decomposition.PCA.transform()
    # It is hardcoded because the number of components has been adjusted
    pca_rep = np.matmul(pca_sklearn.components_[:pca_number], im_reshape.T)

    # The zeros are added with vstack to account for the components that have not been used for the
    # transformation to the lower-dimensional space, while they were initiated with the PCA object.
    # Since inverse_transform uses the number of initial components, the zeros are added for
    # components > pca_number. These components do not impact the inverse transformation.
    zeros = np.zeros(
        (pca_sklearn.n_components - pca_number, im_reshape.shape[0]))
    pca_rep = np.vstack((pca_rep, zeros)).T

    # Transform the data back to the original space
    psf_model = pca_sklearn.inverse_transform(pca_rep)

    # Create an array with the original shape
    residuals = np.zeros((im_shape[0], im_shape[1] * im_shape[2]))

    # Select all pixel indices if set to None
    if indices is None:
        indices = np.arange(0, im_reshape.shape[1], 1)

    # Subtract the PSF model
    residuals[:, indices] = im_reshape - psf_model

    # Reshape the residuals to the original shape
    residuals = residuals.reshape(im_shape)

    # ----------- back scale images
    if scales is not None:

        # Check if the number of scaling factors is equal to the number of images
        if residuals.shape[0] != scales.shape[0]:
            raise ValueError(
                f'The number of images ({residuals.shape[0]}) is not equal to the '
                f'number of wavelengths ({scales.shape[0]}).')

        scal_cor = np.zeros(residuals.shape)

        for i, scale in enumerate(scales):
            # Rescale the residuals by the inverse of the scaling factor
            swaps = scale_image(residuals[i, ], 1 / scale, 1 / scale)

            # Positive: pixels to add (pad); negative: pixels to remove (crop). The size
            # difference is measured along the last axis and applied to both axes. An odd
            # difference puts the extra pixel at the end (npix_del_b).
            npix_del = scal_cor.shape[-1] - swaps.shape[-1]
            npix_del_a = abs(npix_del) // 2
            npix_del_b = abs(npix_del) - npix_del_a

            if npix_del >= 0:
                # Zero-pad; explicit end indices also cover npix_del == 0
                scal_cor[i, npix_del_a:scal_cor.shape[-2] - npix_del_b,
                         npix_del_a:scal_cor.shape[-1] - npix_del_b] = swaps

            else:
                # The rescaled image is larger than the frame (scale < 1), so crop it
                scal_cor[i, ] = swaps[npix_del_a:swaps.shape[-2] - npix_del_b,
                                      npix_del_a:swaps.shape[-1] - npix_del_b]

            if npix_del % 2 == 1:
                # Half-pixel correction for the asymmetric margins: +0.5 pixel after padding,
                # -0.5 pixel after cropping
                half_shift = (0.5, 0.5) if npix_del > 0 else (-0.5, -0.5)
                scal_cor[i, ] = shift_image(scal_cor[i, ], half_shift,
                                            interpolation='spline')

    else:
        scal_cor = residuals

    if angles is not None:

        # Check if the number of parang is equal to the number of images
        if residuals.shape[0] != angles.shape[0]:
            raise ValueError(
                f'The number of images ({residuals.shape[0]}) is not equal to the '
                f'number of parallactic angles ({angles.shape[0]}).')

        res_rot = np.zeros(residuals.shape)

        for j, item in enumerate(angles):
            res_rot[j, ] = rotate(scal_cor[j, ], item, reshape=False)

    else:
        res_rot = scal_cor

    return scal_cor, res_rot
Example #10
0
        def _image_shift(image, shift, interpolation):
            # Thin wrapper that forwards all arguments unchanged to shift_image
            shifted = shift_image(image, shift, interpolation)
            return shifted
Example #11
0
    def run(self) -> None:
        """
        Run method of the module. Locates the position of the calibration spots in the center
        frame. From the four spots, the position of the star behind the coronagraph is fitted,
        and the images are shifted and cropped.

        For 4D (IFS) data the star position is determined separately for each wavelength
        channel. The spot radius is scaled by the ratio of each wavelength to the minimum
        wavelength. The star position is the intersection of the line through spots 0 and 2
        with the line through spots 1 and 3.

        Returns
        -------
        NoneType
            None

        Raises
        ------
        ValueError
            If the science data is not 3D or 4D, if wavelength information is missing or
            inconsistent for 4D data, if the science and center images have a different shape,
            or if ``self.m_pattern`` is not valid.
        """
        @typechecked
        def _get_center(
            image_number: int, center: Optional[Tuple[int, int]]
        ) -> Tuple[np.ndarray, Tuple[int, int]]:

            if center_shape[-3] > 1:
                warnings.warn(
                    'Multiple center images found. Using the first image of the stack.'
                )

            if ndim == 3:
                center_frame = self.m_center_in_port[0, ]
            elif ndim == 4:
                center_frame = self.m_center_in_port[image_number, 0, ]

            if center is None:
                center = center_pixel(center_frame)
            else:
                center = (int(np.floor(center[0])), int(np.floor(center[1])))

            return center_frame, center

        def _intersection(x_spots, y_spots):
            # Intersection of the line through spots 0 and 2 with the line through spots 1
            # and 3, in the determinant form of the line-line intersection. Unlike the
            # slope-intercept form this stays finite when a line is vertical (e.g. spots 1
            # and 3 for an angle of 0 deg).
            x_a, x_b, x_c, x_d = x_spots
            y_a, y_b, y_c, y_d = y_spots

            denom = (x_a - x_c) * (y_b - y_d) - (y_a - y_c) * (x_b - x_d)
            det_ac = x_a * y_c - y_a * x_c
            det_bd = x_b * y_d - y_b * x_d

            x_int = (det_ac * (x_b - x_d) - (x_a - x_c) * det_bd) / denom
            y_int = (det_ac * (y_b - y_d) - (y_a - y_c) * det_bd) / denom

            return x_int, y_int

        center_shape = self.m_center_in_port.get_shape()
        im_shape = self.m_image_in_port.get_shape()
        ndim = self.m_image_in_port.get_ndim()

        # Only 3D and 4D data are supported; other shapes would leave the center frame and
        # the wavelength information undefined
        if ndim not in (3, 4):
            raise ValueError(
                f'The science images should be 3D or 4D but have {ndim} dimensions.')

        center_frame, self.m_center = _get_center(0, self.m_center)

        # Read in wavelength information or set it to default values
        if ndim == 4:
            wavelength = self.m_image_in_port.get_attribute('WAVELENGTH')

            if wavelength is None:
                raise ValueError(
                    'The wavelength information is required to centre IFS data. '
                    'Please add it via the WavelengthReadingModule before using '
                    'the WaffleCenteringModule.')

            if im_shape[0] != center_shape[0]:
                raise ValueError(
                    f'Number of science wavelength channels: {im_shape[0]}. '
                    f'Number of center wavelength channels: {center_shape[0]}. '
                    'Exactly one center image per wavelength is required.')

            wavelength_min = np.min(wavelength)

        elif ndim == 3:
            # For non-IFS data, use default values
            wavelength = [1.]
            wavelength_min = 1.

        # check if science and center images have the same shape
        if im_shape[-2:] != center_shape[-2:]:
            raise ValueError(
                'Science and center images should have the same shape.')

        # Setting angle via pattern (used for backwards compatibility)
        if self.m_pattern is not None:

            if self.m_pattern == 'x':
                self.m_angle = 45.

            elif self.m_pattern == '+':
                self.m_angle = 0.

            else:
                raise ValueError(
                    f'The pattern {self.m_pattern} is not valid. Please select '
                    f'either \'x\' or \'+\'.')

            warnings.warn(
                f'The \'pattern\' parameter will be deprecated in a future release. '
                f'Please Use the \'angle\' parameter instead and set it to '
                f'{self.m_angle} degrees.', DeprecationWarning)

        pixscale = self.m_image_in_port.get_attribute('PIXSCALE')

        self.m_sigma /= pixscale

        if self.m_size is not None:
            self.m_size = int(math.ceil(self.m_size / pixscale))

        if self.m_dither:
            dither_x = self.m_image_in_port.get_attribute('DITHER_X')
            dither_y = self.m_image_in_port.get_attribute('DITHER_Y')

            nframes = self.m_image_in_port.get_attribute('NFRAMES')
            nframes = np.cumsum(nframes)
            nframes = np.insert(nframes, 0, 0)

        # size of center image, only works with odd value
        ref_image_size = 21

        # Arrays for the positions
        x_pos = np.zeros(4)
        y_pos = np.zeros(4)

        # Arrays for the center position for each wavelength
        x_center = np.zeros((len(wavelength)))
        y_center = np.zeros((len(wavelength)))

        # Loop over the wavelength channels
        for w, wave_nr in enumerate(wavelength):

            # Prepare the centering frame
            center_frame, _ = _get_center(w, self.m_center)

            # Subtract a Gaussian-smoothed copy of the center frame
            center_frame_unsharp = center_frame - gaussian_filter(
                input=center_frame, sigma=self.m_sigma)

            # The spot separation scales with wavelength
            radius = self.m_radius * wave_nr / wavelength_min

            # Loop for 4 waffle spots
            for i in range(4):
                # Approximate positions of waffle spots
                x_0 = np.floor(self.m_center[0] + radius *
                               np.cos(self.m_angle * np.pi / 180 + np.pi / 4. *
                                      (2 * i)))

                y_0 = np.floor(self.m_center[1] + radius *
                               np.sin(self.m_angle * np.pi / 180 + np.pi / 4. *
                                      (2 * i)))

                tmp_center_frame = crop_image(image=center_frame_unsharp,
                                              center=(int(y_0), int(x_0)),
                                              size=ref_image_size)

                # find maximum in tmp image
                coords = np.unravel_index(indices=np.argmax(tmp_center_frame),
                                          shape=tmp_center_frame.shape)

                y_max, x_max = coords[0], coords[1]

                pixmax = tmp_center_frame[y_max, x_max]
                max_pos = np.array([x_max, y_max]).reshape(1, 2)

                # Check whether it is the correct maximum: second brightest pixel should be nearby
                tmp_center_frame[y_max, x_max] = 0.

                # introduce distance parameter
                dist = np.inf

                while dist > 2:
                    coords = np.unravel_index(
                        indices=np.argmax(tmp_center_frame),
                        shape=tmp_center_frame.shape)

                    y_max_new, x_max_new = coords[0], coords[1]

                    pixmax_new = tmp_center_frame[y_max_new, x_max_new]

                    # Calculate minimal distance to previous points
                    tmp_center_frame[y_max_new, x_max_new] = 0.

                    dist = np.amin(
                        np.linalg.norm(np.vstack((max_pos[:, 0] - x_max_new,
                                                  max_pos[:, 1] - y_max_new)),
                                       axis=0))

                    if dist <= 2 and pixmax_new < pixmax:
                        break

                    max_pos = np.vstack((max_pos, [x_max_new, y_max_new]))

                    x_max = x_max_new
                    y_max = y_max_new
                    pixmax = pixmax_new

                x_0 = x_0 - (ref_image_size - 1) / 2 + x_max
                y_0 = y_0 - (ref_image_size - 1) / 2 + y_max

                # create reference image around determined maximum
                ref_center_frame = crop_image(image=center_frame_unsharp,
                                              center=(int(y_0), int(x_0)),
                                              size=ref_image_size)

                # Fit the data using astropy.modeling
                gauss_init = models.Gaussian2D(
                    amplitude=np.amax(ref_center_frame),
                    x_mean=x_0,
                    y_mean=y_0,
                    x_stddev=1.,
                    y_stddev=1.,
                    theta=0.)

                fit_gauss = fitting.LevMarLSQFitter()

                y_grid, x_grid = np.mgrid[y_0 - (ref_image_size - 1) / 2:y_0 +
                                          (ref_image_size - 1) / 2 + 1,
                                          x_0 - (ref_image_size - 1) / 2:x_0 +
                                          (ref_image_size - 1) / 2 + 1]

                gauss = fit_gauss(gauss_init, x_grid, y_grid, ref_center_frame)

                x_pos[i] = gauss.x_mean.value
                y_pos[i] = gauss.y_mean.value

            # Find star position as intersection of two lines
            x_center[w], y_center[w] = _intersection(x_pos, y_pos)

        # Adjust science images
        nimages = self.m_image_in_port.get_shape()[-3]
        npix = self.m_image_in_port.get_shape()[-2]
        nwavelengths = len(wavelength)

        start_time = time.time()

        for i in range(nimages):
            im_storage = []
            for j in range(nwavelengths):
                im_index = i * nwavelengths + j

                progress(im_index, nimages * nwavelengths,
                         'Centering the images...', start_time)

                if ndim == 3:
                    image = self.m_image_in_port[i, ]
                elif ndim == 4:
                    image = self.m_image_in_port[j, i, ]

                shift_yx = np.array([
                    (float(im_shape[-2]) - 1.) / 2. - y_center[j],
                    (float(im_shape[-1]) - 1.) / 2. - x_center[j]
                ])

                if self.m_dither:
                    index = np.digitize(i, nframes, right=False) - 1

                    shift_yx[0] -= dither_y[index]
                    shift_yx[1] -= dither_x[index]

                if npix % 2 == 0 and self.m_size is not None:
                    # Add a row and column of zeros so the image has an odd size before cropping
                    im_tmp = np.zeros((image.shape[0] + 1, image.shape[1] + 1))
                    im_tmp[:-1, :-1] = image
                    image = im_tmp

                    shift_yx[0] += 0.5
                    shift_yx[1] += 0.5

                im_shift = shift_image(image, shift_yx, 'spline')

                if self.m_size is not None:
                    im_crop = crop_image(im_shift, None, self.m_size)
                    im_storage.append(im_crop)
                else:
                    im_storage.append(im_shift)

            if ndim == 3:
                self.m_image_out_port.append(im_storage[0], data_dim=3)
            elif ndim == 4:
                self.m_image_out_port.append(np.asarray(im_storage),
                                             data_dim=4)

        print(f'Center [x, y] = [{x_center}, {y_center}]')

        # The history records the center of the last wavelength channel. It is indexed
        # explicitly instead of through the leaked loop variable, which is undefined when
        # there are no images.
        history = f'[x, y] = [{round(x_center[-1], 2)}, {round(y_center[-1], 2)}]'
        self.m_image_out_port.copy_attributes(self.m_image_in_port)
        self.m_image_out_port.add_history('WaffleCenteringModule', history)
        self.m_image_out_port.close_port()
Example #12
0
def pca_psf_subtraction(
        images: np.ndarray,
        angles: Optional[np.ndarray],
        pca_number: Union[int, np.int64],
        scales: Optional[np.ndarray] = None,
        pca_sklearn: Optional[PCA] = None,
        im_shape: Optional[tuple] = None,
        indices: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Function for PSF subtraction with PCA.

    A PCA basis is either fitted on `images` (if `pca_sklearn` is None) or taken from
    `pca_sklearn`. Each image is projected onto the first `pca_number` principal components,
    the resulting PSF model is subtracted, and the residuals are optionally rescaled (SDI,
    with `scales`) and derotated (ADI, with `angles`).

    Parameters
    ----------
    images : np.ndarray
        Stack of images. Also used as reference images if `pca_sklearn` is set to None. Should be
        in the original 3D shape if `pca_sklearn` is set to None or in the 2D reshaped format
        (images x selected pixels) if `pca_sklearn` is not set to None.
    angles : np.ndarray, None
        Derotation angles (deg). The images are not derotated (e.g. for SDI) if set to None.
    pca_number : int
        Number of principal components used for the PSF model. Must not exceed
        ``pca_sklearn.n_components``.
    scales : np.ndarray, None
        Scaling factors for SDI, one per image. Not used if set to None.
    pca_sklearn : sklearn.decomposition.PCA, None
        PCA decomposition of the input data. A new decomposition is fitted on `images` if set
        to None.
    im_shape : tuple(int, int, int), None
        Original shape of the stack with images. Required if `pca_sklearn` is not set to None.
    indices : np.ndarray, None
        Non-masked indices of the flattened images. If set to None, the non-zero pixels of the
        first image are used when `pca_sklearn` is None, and all pixels are used otherwise.

    Returns
    -------
    np.ndarray
        Residuals of the PSF subtraction, passed through `scale_image` with a factor
        1 / `scales` if `scales` is not None.
    np.ndarray
        Derotated version of the first returned array (the same array if `angles` is None).

    Raises
    ------
    ValueError
        If `pca_sklearn` is given without `im_shape`, if the number of `scales` or `angles`
        differs from the number of images, or if a rescaled image is larger than the frame.
    """

    if pca_sklearn is not None and im_shape is None:
        # the 3D shape cannot be recovered from the 2D reshaped input
        raise ValueError('The argument of im_shape is required if pca_sklearn is not None.')

    if pca_sklearn is None:
        pca_sklearn = PCA(n_components=pca_number, svd_solver='arpack')

        im_shape = images.shape

        if indices is None:
            # select the unmasked (non-zero) pixels of the first image
            indices = np.flatnonzero(images[0, ])

        # reshape the images and select the unmasked pixels
        im_reshape = images.reshape(im_shape[0], im_shape[1] * im_shape[2])
        im_reshape = im_reshape[:, indices]

        # subtract the mean image; done out-of-place so that integer input is promoted to
        # float instead of raising a casting error
        im_reshape = im_reshape - np.mean(im_reshape, axis=0)

        # create pca basis
        pca_sklearn.fit(im_reshape)

    else:
        im_reshape = np.copy(images)

    # project onto the first pca_number components and zero-pad the remaining coefficients,
    # because inverse_transform expects n_components coefficients per image
    zeros = np.zeros((pca_sklearn.n_components - pca_number, im_reshape.shape[0]))
    pca_rep = np.matmul(pca_sklearn.components_[:pca_number], im_reshape.T)
    pca_rep = np.vstack((pca_rep, zeros)).T

    # create psf model
    psf_model = pca_sklearn.inverse_transform(pca_rep)

    if indices is None:
        indices = np.arange(0, im_reshape.shape[1], 1)

    # subtract the psf model and write the residuals into full-size (flattened) frames;
    # pixels not listed in indices stay zero
    residuals = np.zeros((im_shape[0], im_shape[1] * im_shape[2]))
    residuals[:, indices] = im_reshape - psf_model

    # reshape to the original image size
    residuals = residuals.reshape(im_shape)

    # ----------- back scale images
    if scales is None:
        scal_cor = residuals

    else:
        # check if the number of scaling factors is equal to the number of images
        if residuals.shape[0] != scales.shape[0]:
            raise ValueError(
                f'The number of images ({residuals.shape[0]}) is not equal to the '
                f'number of wavelengths ({scales.shape[0]}).')

        scal_cor = np.zeros(residuals.shape)

        for i, scale in enumerate(scales):
            # rescaling the images
            swaps = scale_image(residuals[i, ], 1 / scale, 1 / scale)

            npix_del = scal_cor.shape[-1] - swaps.shape[-1]

            if npix_del < 0:
                # the rescaled image cannot be placed inside the original frame
                raise ValueError(
                    f'The rescaled image ({swaps.shape[-1]} pixels) is larger than the '
                    f'original image ({scal_cor.shape[-1]} pixels). Please check the '
                    f'scaling factors.')

            if npix_del == 0:
                scal_cor[i, ] = swaps

            else:
                # pad symmetrically; an odd difference puts the extra pixel at the end and
                # is followed by a half-pixel shift below
                npix_del_a = npix_del // 2
                npix_del_b = npix_del - npix_del_a

                scal_cor[i, npix_del_a:-npix_del_b,
                         npix_del_a:-npix_del_b] = swaps

                if npix_del % 2 == 1:
                    scal_cor[i, ] = shift_image(scal_cor[i, ], (0.5, 0.5),
                                                interpolation='spline')

    # ----------- derotate images
    if angles is None:
        res_rot = scal_cor

    else:
        # check if the number of parallactic angles is equal to the number of images
        if residuals.shape[0] != angles.shape[0]:
            raise ValueError(
                f'The number of images ({residuals.shape[0]}) is not equal to the '
                f'number of parallactic angles ({angles.shape[0]}).')

        res_rot = np.zeros(residuals.shape)

        for j, item in enumerate(angles):
            res_rot[j, ] = rotate(scal_cor[j, ], item, reshape=False)

    return scal_cor, res_rot