def _centering(image, popt):
    """
    Shift ``image`` by the negated values of the first two fit parameters.

    When ``self.m_method`` is ``"full"``, the fit parameters are recomputed for
    this particular image with ``_least_squares`` and ``popt`` is ignored.

    Parameters
    ----------
    image : np.ndarray
        Image to shift.
    popt : sequence
        Fit parameters; only elements 0 and 1 are used. They are passed to
        ``shift_image`` as ``(-popt[1], -popt[0])``, so presumably element 0 is
        the x and element 1 the y offset — TODO confirm.

    Returns
    -------
    np.ndarray
        Shifted image.
    """
    if self.m_method == "full":
        fit_params = _least_squares(np.copy(image))
    else:
        fit_params = popt

    # shift_image receives the second parameter first
    shift = (-fit_params[1], -fit_params[0])

    return shift_image(image, shift, self.m_interpolation)
def run(self) -> None:
    """
    Run method of the module. Shifts an image with a fifth order spline, bilinear, or a
    Fourier shift interpolation.

    If ``self.m_fit_in_port`` is set, the shifts are read from columns 0 (x) and 2 (y) of
    the fit results, negated, and reordered to (y, x) before they are passed to
    ``shift_image``. When all rows of the fit results are equal, the first row is applied
    to every image; otherwise image ``i`` is shifted by row ``i``. Without fit results,
    the constant ``self.m_shift`` is applied to all images.

    The history entry reports the (mean) shift in (x, y) order.

    Returns
    -------
    NoneType
        None
    """

    constant = True

    # read the fit results from the self.m_fit_in_port if available
    if self.m_fit_in_port is not None:
        self.m_shift = -1. * self.m_fit_in_port[:, [0, 2]]  # (x, y)
        self.m_shift = self.m_shift[:, [1, 0]]  # (y, x)

        # The shift is constant when all rows of the fit results are equal to the first
        # row. Note that all columns of the fit results are compared, not only the offsets.
        constant = np.allclose(self.m_fit_in_port.get_all() - self.m_fit_in_port[0, ], 0.0)

        if constant:
            # if the offset is constant then use the first element for all images
            self.m_shift = self.m_shift[0, ]

        else:
            # if the offset is not constant, then apply the shifts to each frame individually
            for i, shift in enumerate(self.m_shift):
                shifted_image = shift_image(self.m_image_in_port[i, ],
                                            shift,
                                            self.m_interpolation)

                # append the shifted images to the self.m_image_out_port database entry
                self.m_image_out_port.append(shifted_image, data_dim=3)

            # mean_shift is ordered (y, x), so index 1 is the x and index 0 the y shift
            mean_shift = np.mean(self.m_shift, axis=0)
            history = f'shift_xy = {mean_shift[1]:.2f}, {mean_shift[0]:.2f}'

    # apply a constant shift
    if constant:
        self.apply_function_to_images(shift_image,
                                      self.m_image_in_port,
                                      self.m_image_out_port,
                                      'Shifting the images',
                                      func_args=(self.m_shift, self.m_interpolation))

        # self.m_shift is passed to shift_image in the same (y, x) order as the fitted
        # shifts above, so the x shift (index 1) is reported first
        history = f'shift_xy = {self.m_shift[1]:.2f}, {self.m_shift[0]:.2f}'

    self.m_image_out_port.copy_attributes(self.m_image_in_port)
    self.m_image_out_port.add_history('ShiftImagesModule', history)
    self.m_image_out_port.close_port()
def run(self) -> None:
    """
    Run method of the module. Normalizes the images for the different filter widths,
    upscales the images, and crops the images to the initial image shape in order to
    align the PSF patterns.

    Each image is scaled with ``scale_image`` by ``line_wvl / cnt_wvl`` and multiplied by
    ``line_width / cnt_width``. The scaled image is then cropped back to the input size.
    If the number of surplus pixels is odd, the crop removes one more pixel at the end than
    at the start, and this is compensated with a (-0.5, -0.5) pixel shift.

    Returns
    -------
    NoneType
        None

    Raises
    ------
    ValueError
        If the scaled images are smaller than the input images.
    """

    self.m_image_out_port.del_all_data()
    self.m_image_out_port.del_all_attributes()

    wvl_factor = self.m_line_wvl / self.m_cnt_wvl
    width_factor = self.m_line_width / self.m_cnt_width

    nimages = self.m_image_in_port.get_shape()[0]

    start_time = time.time()

    for i in range(nimages):
        progress(i, nimages, 'Running SDIpreparationModule...', start_time)

        image = self.m_image_in_port[i, ]

        im_scale = width_factor * scale_image(image, wvl_factor, wvl_factor)

        if i == 0:
            # All images have the same shape, so the crop margins are computed once
            npix_del = im_scale.shape[-1] - image.shape[-1]

            if npix_del < 0:
                # Slicing with negative margins silently produced empty or wrong crops
                raise ValueError(f'The scaled images ({im_scale.shape[-1]} pixels) are '
                                 f'smaller than the input images ({image.shape[-1]} pixels), '
                                 f'which is not supported.')

            # Floor of half the surplus at the start, the remainder at the end
            npix_del_a = npix_del // 2
            npix_del_b = npix_del - npix_del_a

            # A stop index of -0 would select nothing, so use None (keep all) instead
            stop = -npix_del_b if npix_del_b > 0 else None

        im_crop = im_scale[npix_del_a:stop, npix_del_a:stop]

        if npix_del % 2 == 1:
            # Compensate the asymmetric crop
            im_crop = shift_image(im_crop, (-0.5, -0.5), interpolation='spline')

        self.m_image_out_port.append(im_crop, data_dim=3)

    sys.stdout.write('Running SDIpreparationModule... [DONE]\n')
    sys.stdout.flush()

    history = f'(line, continuum) = ({self.m_line_wvl}, {self.m_cnt_wvl})'
    self.m_image_out_port.copy_attributes(self.m_image_in_port)
    self.m_image_out_port.add_history('SDIpreparationModule', history)
    self.m_image_out_port.close_port()
def align_image(image_in: np.ndarray,
                im_index: int,
                interpolation: str,
                accuracy: float,
                resize: Optional[float],
                num_references: int,
                subframe: Optional[float],
                ref_images_reshape: np.ndarray,
                ref_images_shape: Tuple[int, int, int]) -> np.ndarray:
    """
    Align an image with a set of reference images by cross-correlation.

    The offset with respect to each reference image is measured with
    ``phase_cross_correlation``, and the offsets are averaged. The image is then shifted
    by that mean offset, after optional rescaling.

    Parameters
    ----------
    image_in : np.ndarray
        Image to align.
    im_index : int
        Index of the image. Not used here; presumably part of the per-image callback
        signature expected by the caller — TODO confirm.
    interpolation : str
        Interpolation method passed to ``shift_image``.
    accuracy : float
        Passed as ``upsample_factor`` to ``phase_cross_correlation``.
    resize : float, None
        Scaling factor for the image (and the offset). No rescaling if set to None.
    num_references : int
        Number of reference images that are used.
    subframe : float, None
        Size passed to ``crop_image`` for the central region that is cross-correlated.
        The full images are used if set to None.
    ref_images_reshape : np.ndarray
        Reference images in a flattened/reshaped form.
    ref_images_shape : tuple(int, int, int)
        Original 3D shape of the reference images.

    Returns
    -------
    np.ndarray
        Aligned (and optionally rescaled) image.
    """

    # Reshape the reference images back to their original 3D shape
    # The original shape can not be used directly because of util.module.update_arguments
    ref_images = ref_images_reshape.reshape(ref_images_shape)

    # The science (sub)image is the same for every reference, so it is cropped only once
    sub_in = image_in if subframe is None else crop_image(image_in, None, subframe)

    offset = np.array([0., 0.])

    for i in range(num_references):
        sub_ref = ref_images[i, :, :]

        if subframe is not None:
            sub_ref = crop_image(sub_ref, None, subframe)

        tmp_offset, _, _ = phase_cross_correlation(sub_ref, sub_in, upsample_factor=accuracy)
        offset += tmp_offset

    offset /= float(num_references)

    if resize is not None:
        offset *= resize

        sum_before = np.sum(image_in)

        tmp_image = rescale(image_in, (resize, resize), order=5, mode='reflect',
                            multichannel=False, anti_aliasing=True)

        sum_after = np.sum(tmp_image)

        # Conserve flux because the rescale function normalizes all values to [0:1].
        # A rescaled image that sums to zero (e.g. a blank frame) is left as it is,
        # because dividing by zero would turn the whole image into NaN.
        if sum_after != 0.:
            tmp_image = tmp_image * (sum_before / sum_after)

    else:
        tmp_image = image_in

    return shift_image(tmp_image, offset, interpolation)
def sdi_scaling(image_in: np.ndarray,
                scaling: np.ndarray) -> np.ndarray:
    """
    Rescale every image of a stack by its own factor and crop it back to the input size.

    If the rescaled image has an odd number of surplus pixels, one more pixel is removed at
    the end than at the start. This is compensated with a (-0.5, -0.5) pixel shift.

    Parameters
    ----------
    image_in : np.ndarray
        Stack of images to rescale (one image per wavelength).
    scaling : np.ndarray
        Scaling factor for each image of the stack.

    Returns
    -------
    np.ndarray
        Rescaled images with the same shape as ``image_in``.

    Raises
    ------
    ValueError
        If the number of images and the number of scaling factors differ.
    """

    n_images = image_in.shape[0]

    if n_images != scaling.shape[0]:
        raise ValueError('The number of wavelengths is not equal to the number of available '
                         'scaling factors.')

    image_out = np.zeros(image_in.shape)

    for idx, factor in enumerate(scaling):
        scaled = scale_image(image_in[idx, ], factor, factor)
        excess = scaled.shape[-1] - image_out.shape[-1]

        if excess == 0:
            image_out[idx, ] = scaled
            continue

        # floor(excess / 2) pixels removed at the start, the remainder at the end
        lower = excess // 2
        upper = excess - lower

        image_out[idx, ] = scaled[lower:-upper, lower:-upper]

        if excess % 2 == 1:
            image_out[idx, ] = shift_image(image_out[idx, ], (-0.5, -0.5),
                                           interpolation='spline')

    return image_out
def _align_image(image_in):
    """
    Align an image with the reference images by cross-correlation.

    This is a closure: it reads ``ref_images`` from the enclosing scope and the
    settings ``m_num_references``, ``m_subframe``, ``m_accuracy``, ``m_resize``, and
    ``m_interpolation`` from ``self``. The offset with respect to each reference image
    is measured with ``register_translation``, and the offsets are averaged. The image
    is then shifted by that mean offset, after optional rescaling.

    Parameters
    ----------
    image_in : np.ndarray
        Image to align.

    Returns
    -------
    np.ndarray
        Aligned (and optionally rescaled) image.
    """

    # The science (sub)image is the same for every reference, so it is cropped only once
    if self.m_subframe is None:
        sub_in = image_in
    else:
        sub_in = crop_image(image_in, None, self.m_subframe)

    offset = np.array([0., 0.])

    for i in range(self.m_num_references):
        sub_ref = ref_images[i, :, :]

        if self.m_subframe is not None:
            sub_ref = crop_image(sub_ref, None, self.m_subframe)

        tmp_offset, _, _ = register_translation(sub_ref,
                                                sub_in,
                                                upsample_factor=self.m_accuracy)
        offset += tmp_offset

    offset /= float(self.m_num_references)

    if self.m_resize is not None:
        offset *= self.m_resize

        sum_before = np.sum(image_in)

        tmp_image = rescale(image=np.asarray(image_in, dtype=np.float64),
                            scale=(self.m_resize, self.m_resize),
                            order=5,
                            mode='reflect',
                            anti_aliasing=True,
                            multichannel=False)

        sum_after = np.sum(tmp_image)

        # Conserve flux because the rescale function normalizes all values to [0:1].
        # A rescaled image that sums to zero (e.g. a blank frame) is left as it is,
        # because dividing by zero would turn the whole image into NaN.
        if sum_after != 0.:
            tmp_image = tmp_image * (sum_before / sum_after)

    else:
        tmp_image = image_in

    return shift_image(tmp_image, offset, self.m_interpolation)
def apply_shift(image_in: np.ndarray,
                im_index: int,
                shift: Union[Tuple[float, float], np.ndarray],
                interpolation: str) -> np.ndarray:
    """
    Shift a single image with ``shift_image``.

    Parameters
    ----------
    image_in : np.ndarray
        Image to shift.
    im_index : int
        Index of the image. Not used here; presumably part of the per-image callback
        signature expected by the caller — TODO confirm.
    shift : tuple(float, float), np.ndarray
        Shift passed unchanged to ``shift_image``.
    interpolation : str
        Interpolation method passed unchanged to ``shift_image``.

    Returns
    -------
    np.ndarray
        Shifted image.
    """

    shifted = shift_image(image_in, shift, interpolation)

    return shifted
def run(self) -> None:
    """
    Run method of the module. Locates the position of the calibration spots in the center
    frame. From the four spots, the position of the star behind the coronagraph is fitted,
    and the images are shifted and cropped.

    The star position is the intersection of the line through spots 0 and 2 with the line
    through spots 1 and 3.

    Returns
    -------
    NoneType
        None

    Raises
    ------
    ValueError
        If the science and center images differ in shape, if the pattern is not 'x' or '+',
        or if the lines through opposite spots are parallel.
    """

    def _get_center(center):
        # Only the first center frame is used
        center_frame = self.m_center_in_port[0, ]

        if center_shape[0] > 1:
            warnings.warn('Multiple center images found. Using the first image of the stack.')

        if center is None:
            center = center_pixel(center_frame)
        else:
            center = (np.floor(center[0]), np.floor(center[1]))

        return center_frame, center

    def _fit_spot(frame_unsharp, x_0, y_0):
        # Refine the approximate position (x_0, y_0) of one waffle spot and return the
        # fitted (x, y) position. A maximum is accepted once a dimmer pixel is found within
        # 2 pixels of it (presumably to skip isolated bright pixels). A 2D Gaussian is
        # then fitted around that maximum.
        half_size = (ref_image_size - 1) / 2

        tmp_center_frame = crop_image(image=frame_unsharp,
                                      center=(int(y_0), int(x_0)),
                                      size=ref_image_size)

        # find maximum in tmp image
        coords = np.unravel_index(indices=np.argmax(tmp_center_frame),
                                  shape=tmp_center_frame.shape)

        y_max, x_max = coords[0], coords[1]

        pixmax = tmp_center_frame[y_max, x_max]
        max_pos = np.array([x_max, y_max]).reshape(1, 2)

        # Check whether it is the correct maximum: second brightest pixel should be nearby
        tmp_center_frame[y_max, x_max] = 0.

        # introduce distance parameter
        dist = np.inf

        while dist > 2:
            coords = np.unravel_index(indices=np.argmax(tmp_center_frame),
                                      shape=tmp_center_frame.shape)

            y_max_new, x_max_new = coords[0], coords[1]

            pixmax_new = tmp_center_frame[y_max_new, x_max_new]

            # Calculate minimal distance to previous points
            tmp_center_frame[y_max_new, x_max_new] = 0.

            dist = np.amin(np.linalg.norm(np.vstack((max_pos[:, 0] - x_max_new,
                                                     max_pos[:, 1] - y_max_new)),
                                          axis=0))

            if dist <= 2 and pixmax_new < pixmax:
                break

            max_pos = np.vstack((max_pos, [x_max_new, y_max_new]))

            x_max = x_max_new
            y_max = y_max_new
            pixmax = pixmax_new

        x_0 = x_0 - half_size + x_max
        y_0 = y_0 - half_size + y_max

        # create reference image around determined maximum
        ref_center_frame = crop_image(image=frame_unsharp,
                                      center=(int(y_0), int(x_0)),
                                      size=ref_image_size)

        # Fit the data using astropy.modeling
        gauss_init = models.Gaussian2D(amplitude=np.amax(ref_center_frame),
                                       x_mean=x_0,
                                       y_mean=y_0,
                                       x_stddev=1.,
                                       y_stddev=1.,
                                       theta=0.)

        fit_gauss = fitting.LevMarLSQFitter()

        y_grid, x_grid = np.mgrid[y_0 - half_size:y_0 + half_size + 1,
                                  x_0 - half_size:x_0 + half_size + 1]

        gauss = fit_gauss(gauss_init, x_grid, y_grid, ref_center_frame)

        return gauss.x_mean.value, gauss.y_mean.value

    def _intersection(x_pos, y_pos):
        # Intersection of the line through spots 0 and 2 with the line through spots 1
        # and 3, in determinant form. Unlike a slope-intercept form, this also works when
        # a line is vertical, e.g. for the '+' pattern.
        x_a, y_a, x_b, y_b = x_pos[0], y_pos[0], x_pos[2], y_pos[2]
        x_c, y_c, x_d, y_d = x_pos[1], y_pos[1], x_pos[3], y_pos[3]

        denom = (x_a - x_b) * (y_c - y_d) - (y_a - y_b) * (x_c - x_d)

        if denom == 0.:
            raise ValueError('The lines through opposite waffle spots are parallel so the '
                             'position of the star can not be determined.')

        det_ab = x_a * y_b - y_a * x_b
        det_cd = x_c * y_d - y_c * x_d

        x_cross = (det_ab * (x_c - x_d) - (x_a - x_b) * det_cd) / denom
        y_cross = (det_ab * (y_c - y_d) - (y_a - y_b) * det_cd) / denom

        return x_cross, y_cross

    self.m_image_out_port.del_all_data()
    self.m_image_out_port.del_all_attributes()

    center_shape = self.m_center_in_port.get_shape()
    im_shape = self.m_image_in_port.get_shape()

    center_frame, self.m_center = _get_center(self.m_center)

    if im_shape[-2:] != center_shape[-2:]:
        raise ValueError('Science and center images should have the same shape.')

    # Validate the pattern up front; an unknown value previously left the spot
    # positions undefined and failed with a NameError
    if self.m_pattern not in ('x', '+'):
        raise ValueError(f'The pattern {self.m_pattern} is not valid. Please select '
                         f'either \'x\' or \'+\'.')

    pixscale = self.m_image_in_port.get_attribute('PIXSCALE')

    # Convert to pixels in local variables, so that running the module a second time
    # does not divide the attributes by the pixel scale again
    sigma = self.m_sigma / pixscale

    if self.m_size is None:
        size = None
    else:
        size = int(math.ceil(self.m_size / pixscale))

    if self.m_dither:
        dither_x = self.m_image_in_port.get_attribute('DITHER_X')
        dither_y = self.m_image_in_port.get_attribute('DITHER_Y')

        nframes = self.m_image_in_port.get_attribute('NFRAMES')
        nframes = np.cumsum(nframes)
        nframes = np.insert(nframes, 0, 0)

    center_frame_unsharp = center_frame - gaussian_filter(input=center_frame, sigma=sigma)

    # size of center image, only works with odd value
    ref_image_size = 21

    # Arrays for the positions
    x_pos = np.zeros(4)
    y_pos = np.zeros(4)

    # Loop for 4 waffle spots
    for i in range(4):
        # Approximate positions of waffle spots: diagonal ('x') or axis-aligned ('+')
        if self.m_pattern == 'x':
            angle = np.pi / 4. * (2 * i + 1)
        else:
            angle = np.pi / 4. * (2 * i)

        x_0 = np.floor(self.m_center[0] + self.m_radius * np.cos(angle))
        y_0 = np.floor(self.m_center[1] + self.m_radius * np.sin(angle))

        x_pos[i], y_pos[i] = _fit_spot(center_frame_unsharp, x_0, y_0)

    # Find star position as intersection of two lines
    x_center, y_center = _intersection(x_pos, y_pos)

    nimages = im_shape[0]
    npix = im_shape[1]

    start_time = time.time()

    for i in range(nimages):
        progress(i, nimages, 'Centering the images...', start_time)

        image = self.m_image_in_port[i, ]

        shift_yx = np.array([(float(im_shape[-2]) - 1.) / 2. - y_center,
                             (float(im_shape[-1]) - 1.) / 2. - x_center])

        if self.m_dither:
            index = np.digitize(i, nframes, right=False) - 1

            shift_yx[0] -= dither_y[index]
            shift_yx[1] -= dither_x[index]

        if npix % 2 == 0 and size is not None:
            # Pad an even-sized image by one row and column and move it by half a pixel
            im_tmp = np.zeros((image.shape[0] + 1, image.shape[1] + 1))
            im_tmp[:-1, :-1] = image
            image = im_tmp

            shift_yx[0] += 0.5
            shift_yx[1] += 0.5

        im_shift = shift_image(image, shift_yx, 'spline')

        if size is not None:
            im_crop = crop_image(im_shift, None, size)
            self.m_image_out_port.append(im_crop, data_dim=3)
        else:
            self.m_image_out_port.append(im_shift, data_dim=3)

    print(f'Center [x, y] = [{x_center}, {y_center}]')

    history = f'[x, y] = [{round(x_center, 2)}, {round(y_center, 2)}]'
    self.m_image_out_port.copy_attributes(self.m_image_in_port)
    self.m_image_out_port.add_history('WaffleCenteringModule', history)
    self.m_image_out_port.close_port()
def pca_psf_subtraction(images: np.ndarray,
                        angles: Optional[np.ndarray],
                        pca_number: Union[int, np.int64],
                        scales: Optional[np.ndarray] = None,
                        pca_sklearn: Optional[PCA] = None,
                        im_shape: Optional[tuple] = None,
                        indices: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Function for PSF subtraction with PCA.

    A PSF model is built from the first ``pca_number`` principal components and subtracted
    from the images. The residuals are optionally scaled back (SDI) and derotated.

    Parameters
    ----------
    images : np.ndarray
        Stack of images. These are also used as reference images if ``pca_sklearn`` is set
        to None. The data should have the original 3D shape if ``pca_sklearn`` is None, or
        the 2D reshaped format if ``pca_sklearn`` is not None.
    angles : np.ndarray, None
        Derotation angles (deg). The residuals are not derotated if set to None.
    pca_number : int
        Number of principal components used for the PSF model.
    scales : np.ndarray, None
        Scaling factors for SDI. Not used if set to None.
    pca_sklearn : sklearn.decomposition.pca.PCA, None
        PCA object with the principal components. A new PCA is fitted to ``images`` if set
        to None.
    im_shape : tuple(int, int, int), None
        The original 3D shape of the stack with images. Only required if ``pca_sklearn`` is
        not set to None.
    indices : np.ndarray, None
        Array with the indices of the pixels that are used for the PSF subtraction. If
        ``pca_sklearn`` is None, these are the nonzero pixels of the first image. If a PCA
        object is given, all pixels are used.

    Returns
    -------
    np.ndarray
        Residuals of the PSF subtraction.
    np.ndarray
        Derotated residuals of the PSF subtraction.

    Raises
    ------
    ValueError
        If the number of images does not match the number of scaling factors or angles.
    """

    if pca_sklearn is not None:
        # A fitted PCA object comes with data that are already in the 2D format
        im_reshape = np.copy(images)

    else:
        pca_sklearn = PCA(n_components=pca_number, svd_solver='arpack')
        im_shape = images.shape

        if indices is None:
            # The unmasked pixels are those that are nonzero in the first image
            indices = np.where(images[0, ].reshape(-1) != 0.)[0]

        n_pix = im_shape[1] * im_shape[2]
        im_reshape = images.reshape(im_shape[0], n_pix)[:, indices]

        # Subtract the mean image (sklearn.decomposition.PCA.fit() also centers the data)
        im_reshape -= np.mean(im_reshape, axis=0)

        pca_sklearn.fit(im_reshape)

    # Projection on the first pca_number components. This is what
    # sklearn.decomposition.PCA.transform() does, but written out because the number
    # of components can be lower than the number that was fitted.
    coefficients = np.matmul(pca_sklearn.components_[:pca_number], im_reshape.T)

    # inverse_transform expects coefficients for all fitted components, so the unused
    # components are padded with zeros; these do not contribute to the PSF model
    n_unused = pca_sklearn.n_components - pca_number
    padding = np.zeros((n_unused, im_reshape.shape[0]))
    pca_rep = np.vstack((coefficients, padding)).T

    psf_model = pca_sklearn.inverse_transform(pca_rep)

    if indices is None:
        indices = np.arange(0, im_reshape.shape[1], 1)

    # Residuals in the full (unmasked) pixel grid, reshaped to the original 3D shape
    residuals = np.zeros((im_shape[0], im_shape[1] * im_shape[2]))
    residuals[:, indices] = im_reshape - psf_model
    residuals = residuals.reshape(im_shape)

    # ----------- back scale images
    if scales is None:
        scal_cor = residuals

    else:
        # Check that there is one scaling factor (wavelength) per image
        if residuals.shape[0] != scales.shape[0]:
            raise ValueError(f'The number of images ({residuals.shape[0]}) is not equal to the '
                             f'number of wavelengths ({scales.shape[0]}).')

        scal_cor = np.zeros(residuals.shape)

        for i, scale in enumerate(scales):
            swaps = scale_image(residuals[i, ], 1 / scale, 1 / scale)
            npix_del = scal_cor.shape[-1] - swaps.shape[-1]

            if npix_del == 0:
                scal_cor[i, ] = swaps
                continue

            # Place the smaller image in the center: floor(npix_del / 2) pixels of
            # margin at the start, the remainder at the end
            lower = npix_del // 2
            upper = npix_del - lower

            scal_cor[i, lower:-upper, lower:-upper] = swaps

            if npix_del % 2 == 1:
                scal_cor[i, ] = shift_image(scal_cor[i, ], (0.5, 0.5),
                                            interpolation='spline')

    if angles is None:
        res_rot = scal_cor

    else:
        # Check that there is one angle per image
        if residuals.shape[0] != angles.shape[0]:
            raise ValueError(f'The number of images ({residuals.shape[0]}) is not equal to the '
                             f'number of parallactic angles ({angles.shape[0]}).')

        res_rot = np.zeros(residuals.shape)

        for j, angle in enumerate(angles):
            res_rot[j, ] = rotate(scal_cor[j, ], angle, reshape=False)

    return scal_cor, res_rot
def _image_shift(image, shift, interpolation):
    """
    Shift a single image with ``shift_image``.

    Parameters
    ----------
    image : np.ndarray
        Image to shift.
    shift : tuple(float, float), np.ndarray
        Shift passed unchanged to ``shift_image``.
    interpolation : str
        Interpolation method passed unchanged to ``shift_image``.

    Returns
    -------
    np.ndarray
        Shifted image.
    """

    shifted = shift_image(image, shift, interpolation)

    return shifted
def run(self) -> None:
    """
    Run method of the module. Locates the position of the calibration spots in the center
    frame. From the four spots, the position of the star behind the coronagraph is fitted,
    and the images are shifted and cropped.

    For 4D (IFS) data, the spots are located separately for each wavelength channel, with
    the spot radius scaled by ``wavelength / min(wavelength)``. The star position is the
    intersection of the line through spots 0 and 2 with the line through spots 1 and 3.

    Returns
    -------
    NoneType
        None

    Raises
    ------
    ValueError
        If wavelength information is missing or inconsistent for IFS data, if the science
        and center images differ in shape, if the pattern is invalid, or if the lines
        through opposite spots are parallel.
    """

    @typechecked
    def _get_center(image_number: int,
                    center: Optional[Tuple[int, int]]) -> Tuple[np.ndarray, Tuple[int, int]]:

        if center_shape[-3] > 1:
            warnings.warn('Multiple center images found. Using the first image of the stack.')

        if ndim == 3:
            center_frame = self.m_center_in_port[0, ]
        elif ndim == 4:
            center_frame = self.m_center_in_port[image_number, 0, ]

        if center is None:
            center = center_pixel(center_frame)
        else:
            center = (int(np.floor(center[0])), int(np.floor(center[1])))

        return center_frame, center

    def _fit_spot(frame_unsharp, x_0, y_0):
        # Refine the approximate position (x_0, y_0) of one waffle spot and return the
        # fitted (x, y) position. A maximum is accepted once a dimmer pixel is found within
        # 2 pixels of it (presumably to skip isolated bright pixels). A 2D Gaussian is
        # then fitted around that maximum.
        half_size = (ref_image_size - 1) / 2

        tmp_center_frame = crop_image(image=frame_unsharp,
                                      center=(int(y_0), int(x_0)),
                                      size=ref_image_size)

        # find maximum in tmp image
        coords = np.unravel_index(indices=np.argmax(tmp_center_frame),
                                  shape=tmp_center_frame.shape)

        y_max, x_max = coords[0], coords[1]

        pixmax = tmp_center_frame[y_max, x_max]
        max_pos = np.array([x_max, y_max]).reshape(1, 2)

        # Check whether it is the correct maximum: second brightest pixel should be nearby
        tmp_center_frame[y_max, x_max] = 0.

        # introduce distance parameter
        dist = np.inf

        while dist > 2:
            coords = np.unravel_index(indices=np.argmax(tmp_center_frame),
                                      shape=tmp_center_frame.shape)

            y_max_new, x_max_new = coords[0], coords[1]

            pixmax_new = tmp_center_frame[y_max_new, x_max_new]

            # Calculate minimal distance to previous points
            tmp_center_frame[y_max_new, x_max_new] = 0.

            dist = np.amin(np.linalg.norm(np.vstack((max_pos[:, 0] - x_max_new,
                                                     max_pos[:, 1] - y_max_new)),
                                          axis=0))

            if dist <= 2 and pixmax_new < pixmax:
                break

            max_pos = np.vstack((max_pos, [x_max_new, y_max_new]))

            x_max = x_max_new
            y_max = y_max_new
            pixmax = pixmax_new

        x_0 = x_0 - half_size + x_max
        y_0 = y_0 - half_size + y_max

        # create reference image around determined maximum
        ref_center_frame = crop_image(image=frame_unsharp,
                                      center=(int(y_0), int(x_0)),
                                      size=ref_image_size)

        # Fit the data using astropy.modeling
        gauss_init = models.Gaussian2D(amplitude=np.amax(ref_center_frame),
                                       x_mean=x_0,
                                       y_mean=y_0,
                                       x_stddev=1.,
                                       y_stddev=1.,
                                       theta=0.)

        fit_gauss = fitting.LevMarLSQFitter()

        y_grid, x_grid = np.mgrid[y_0 - half_size:y_0 + half_size + 1,
                                  x_0 - half_size:x_0 + half_size + 1]

        gauss = fit_gauss(gauss_init, x_grid, y_grid, ref_center_frame)

        return gauss.x_mean.value, gauss.y_mean.value

    def _intersection(x_pos, y_pos):
        # Intersection of the line through spots 0 and 2 with the line through spots 1
        # and 3, in determinant form. Unlike a slope-intercept form, this also works when
        # a line is vertical, e.g. for an angle of 0 deg.
        x_a, y_a, x_b, y_b = x_pos[0], y_pos[0], x_pos[2], y_pos[2]
        x_c, y_c, x_d, y_d = x_pos[1], y_pos[1], x_pos[3], y_pos[3]

        denom = (x_a - x_b) * (y_c - y_d) - (y_a - y_b) * (x_c - x_d)

        if denom == 0.:
            raise ValueError('The lines through opposite waffle spots are parallel so the '
                             'position of the star can not be determined.')

        det_ab = x_a * y_b - y_a * x_b
        det_cd = x_c * y_d - y_c * x_d

        x_cross = (det_ab * (x_c - x_d) - (x_a - x_b) * det_cd) / denom
        y_cross = (det_ab * (y_c - y_d) - (y_a - y_b) * det_cd) / denom

        return x_cross, y_cross

    center_shape = self.m_center_in_port.get_shape()
    im_shape = self.m_image_in_port.get_shape()
    ndim = self.m_image_in_port.get_ndim()

    center_frame, self.m_center = _get_center(0, self.m_center)

    # Read in wavelength information or set it to default values
    if ndim == 4:
        wavelength = self.m_image_in_port.get_attribute('WAVELENGTH')

        if wavelength is None:
            raise ValueError('The wavelength information is required to centre IFS data. '
                             'Please add it via the WavelengthReadingModule before using '
                             'the WaffleCenteringModule.')

        if im_shape[0] != center_shape[0]:
            raise ValueError(f'Number of science wavelength channels: {im_shape[0]}. '
                             f'Number of center wavelength channels: {center_shape[0]}. '
                             'Exactly one center image per wavelength is required.')

        wavelength_min = np.min(wavelength)

    elif ndim == 3:
        # for none ifs data, use default value
        wavelength = [1.]
        wavelength_min = 1.

    # check if science and center images have the same shape
    if im_shape[-2:] != center_shape[-2:]:
        raise ValueError('Science and center images should have the same shape.')

    # Setting angle via pattern (used for backwards compatibility)
    if self.m_pattern is not None:
        if self.m_pattern == 'x':
            self.m_angle = 45.
        elif self.m_pattern == '+':
            self.m_angle = 0.
        else:
            raise ValueError(f'The pattern {self.m_pattern} is not valid. Please select '
                             f'either \'x\' or \'+\'.')

        warnings.warn(f'The \'pattern\' parameter will be deprecated in a future release. '
                      f'Please Use the \'angle\' parameter instead and set it to '
                      f'{self.m_angle} degrees.', DeprecationWarning)

    pixscale = self.m_image_in_port.get_attribute('PIXSCALE')

    # Convert to pixels in local variables, so that running the module a second time
    # does not divide the attributes by the pixel scale again
    sigma = self.m_sigma / pixscale

    if self.m_size is None:
        size = None
    else:
        size = int(math.ceil(self.m_size / pixscale))

    if self.m_dither:
        dither_x = self.m_image_in_port.get_attribute('DITHER_X')
        dither_y = self.m_image_in_port.get_attribute('DITHER_Y')

        nframes = self.m_image_in_port.get_attribute('NFRAMES')
        nframes = np.cumsum(nframes)
        nframes = np.insert(nframes, 0, 0)

    # size of center image, only works with odd value
    ref_image_size = 21

    # Arrays for the positions
    x_pos = np.zeros(4)
    y_pos = np.zeros(4)

    # Arrays for the center position for each wavelength
    x_center = np.zeros((len(wavelength)))
    y_center = np.zeros((len(wavelength)))

    for w, wave_nr in enumerate(wavelength):
        # Prepare the centering frame of this wavelength
        center_frame, _ = _get_center(w, self.m_center)

        center_frame_unsharp = center_frame - gaussian_filter(input=center_frame, sigma=sigma)

        # The spot separation scales with the wavelength
        radius = self.m_radius * wave_nr / wavelength_min

        # Loop for 4 waffle spots
        for i in range(4):
            # Approximate positions of waffle spots
            x_0 = np.floor(self.m_center[0] + radius *
                           np.cos(self.m_angle * np.pi / 180 + np.pi / 4. * (2 * i)))
            y_0 = np.floor(self.m_center[1] + radius *
                           np.sin(self.m_angle * np.pi / 180 + np.pi / 4. * (2 * i)))

            x_pos[i], y_pos[i] = _fit_spot(center_frame_unsharp, x_0, y_0)

        # Find star position as intersection of two lines
        x_center[w], y_center[w] = _intersection(x_pos, y_pos)

    # Adjust science images
    nimages = im_shape[-3]
    npix = im_shape[-2]
    nwavelengths = len(wavelength)

    start_time = time.time()

    for i in range(nimages):
        im_storage = []

        for j in range(nwavelengths):
            im_index = i * nwavelengths + j

            progress(im_index, nimages * nwavelengths, 'Centering the images...', start_time)

            if ndim == 3:
                image = self.m_image_in_port[i, ]
            elif ndim == 4:
                image = self.m_image_in_port[j, i, ]

            shift_yx = np.array([(float(im_shape[-2]) - 1.) / 2. - y_center[j],
                                 (float(im_shape[-1]) - 1.) / 2. - x_center[j]])

            if self.m_dither:
                index = np.digitize(i, nframes, right=False) - 1

                shift_yx[0] -= dither_y[index]
                shift_yx[1] -= dither_x[index]

            if npix % 2 == 0 and size is not None:
                # Pad an even-sized image by one row and column and move it by half a pixel
                im_tmp = np.zeros((image.shape[0] + 1, image.shape[1] + 1))
                im_tmp[:-1, :-1] = image
                image = im_tmp

                shift_yx[0] += 0.5
                shift_yx[1] += 0.5

            im_shift = shift_image(image, shift_yx, 'spline')

            if size is not None:
                im_crop = crop_image(im_shift, None, size)
                im_storage.append(im_crop)
            else:
                im_storage.append(im_shift)

        if ndim == 3:
            self.m_image_out_port.append(im_storage[0], data_dim=3)
        elif ndim == 4:
            self.m_image_out_port.append(np.asarray(im_storage), data_dim=4)

    print(f'Center [x, y] = [{x_center}, {y_center}]')

    # Only the center of the last wavelength channel is stored in the history. Index -1
    # is used explicitly, because the loop variable is not defined without images.
    history = f'[x, y] = [{round(x_center[-1], 2)}, {round(y_center[-1], 2)}]'
    self.m_image_out_port.copy_attributes(self.m_image_in_port)
    self.m_image_out_port.add_history('WaffleCenteringModule', history)
    self.m_image_out_port.close_port()
def pca_psf_subtraction(images: np.ndarray,
                        angles: Optional[np.ndarray],
                        pca_number: Union[int, np.int64],
                        scales: Optional[np.ndarray] = None,
                        pca_sklearn: Optional[PCA] = None,
                        im_shape: Optional[tuple] = None,
                        indices: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
    """
    Function for PSF subtraction with PCA.

    The PSF model is the reconstruction of the data from the first ``pca_number``
    principal components. The residuals (data minus model) are optionally scaled back
    with ``1 / scales`` and optionally derotated by ``angles``.

    Parameters
    ----------
    images : np.ndarray
        Stack of images. Also used as reference images if `pca_sklearn` is set to None.
        Should be in the original 3D shape if `pca_sklearn` is set to None or in the 2D
        reshaped format if `pca_sklearn` is not set to None.
    angles : np.ndarray, None
        Derotation angles (deg). The images are not derotated (e.g. for SDI) if set to
        None.
    pca_number : int
        Number of principal components used for the PSF model.
    scales : np.ndarray, None
        Scaling factors for SDI. Not used if set to None.
    pca_sklearn : sklearn.decomposition.pca.PCA, None
        PCA decomposition of the input data. A new decomposition is fitted if set to None.
    im_shape : tuple(int, int, int), None
        Original shape of the stack with images. Required if `pca_sklearn` is not set to
        None.
    indices : np.ndarray, None
        Non-masked image indices. When `pca_sklearn` is None these default to the nonzero
        pixels of the first image; otherwise all pixels are used if set to None.

    Returns
    -------
    np.ndarray
        Residuals of the PSF subtraction.
    np.ndarray
        Derotated residuals of the PSF subtraction.

    Raises
    ------
    ValueError
        If the number of images does not match the number of scaling factors or angles.
    """

    if pca_sklearn is not None:
        # The data are already reshaped when a PCA decomposition is provided
        im_reshape = np.copy(images)

    else:
        pca_sklearn = PCA(n_components=pca_number, svd_solver='arpack')
        im_shape = images.shape

        if indices is None:
            # unmasked pixels: nonzero pixels of the first image
            indices = np.where(images[0, ].reshape(-1) != 0.)[0]

        # flatten the images and keep only the unmasked pixels
        n_pix = im_shape[1] * im_shape[2]
        im_reshape = images.reshape(im_shape[0], n_pix)[:, indices]

        # subtract the mean image before creating the PCA basis
        im_reshape -= np.mean(im_reshape, axis=0)

        pca_sklearn.fit(im_reshape)

    # PCA representation with the first pca_number components; the remaining fitted
    # components get zero coefficients so that inverse_transform accepts the array
    coefficients = np.matmul(pca_sklearn.components_[:pca_number], im_reshape.T)
    n_unused = pca_sklearn.n_components - pca_number
    padding = np.zeros((n_unused, im_reshape.shape[0]))
    pca_rep = np.vstack((coefficients, padding)).T

    psf_model = pca_sklearn.inverse_transform(pca_rep)

    if indices is None:
        indices = np.arange(0, im_reshape.shape[1], 1)

    # residuals in the full pixel grid, reshaped to the original image size
    residuals = np.zeros((im_shape[0], im_shape[1] * im_shape[2]))
    residuals[:, indices] = im_reshape - psf_model
    residuals = residuals.reshape(im_shape)

    # ----------- back scale images
    if scales is None:
        scal_cor = residuals

    else:
        # one scaling factor (wavelength) is required per image
        if residuals.shape[0] != scales.shape[0]:
            raise ValueError(f'The number of images ({residuals.shape[0]}) is not equal to the '
                             f'number of wavelengths ({scales.shape[0]}).')

        scal_cor = np.zeros(residuals.shape)

        for i, scale in enumerate(scales):
            swaps = scale_image(residuals[i, ], 1 / scale, 1 / scale)
            npix_del = scal_cor.shape[-1] - swaps.shape[-1]

            if npix_del == 0:
                scal_cor[i, ] = swaps
                continue

            # center the smaller image: floor(npix_del / 2) pixels of margin at the
            # start, the remainder at the end
            lower = npix_del // 2
            upper = npix_del - lower

            scal_cor[i, lower:-upper, lower:-upper] = swaps

            if npix_del % 2 == 1:
                scal_cor[i, ] = shift_image(scal_cor[i, ], (0.5, 0.5),
                                            interpolation='spline')

    if angles is None:
        res_rot = scal_cor

    else:
        # one derotation angle is required per image
        if residuals.shape[0] != angles.shape[0]:
            raise ValueError(f'The number of images ({residuals.shape[0]}) is not equal to the '
                             f'number of parallactic angles ({angles.shape[0]}).')

        res_rot = np.zeros(residuals.shape)

        for j, angle in enumerate(angles):
            res_rot[j, ] = rotate(scal_cor[j, ], angle, reshape=False)

    return scal_cor, res_rot