Ejemplo n.º 1
0
    def _run_single_processing(self, star_reshape: np.ndarray, im_shape: tuple,
                               indices: Optional[np.ndarray]) -> None:
        """
        Internal function to create the residuals, derotate the images, and write the output
        using a single process.

        For every combination of principal-component numbers, the residuals are created with
        ``postprocessor`` and then combined (mean, median, noise-weighted, and clipped mean)
        with ``combine_residuals``. Combined results are only stored for output ports that
        are not ``None``.

        Parameters
        ----------
        star_reshape : np.ndarray
            Science images, passed unchanged to ``postprocessor`` as ``images``.
        im_shape : tuple
            Image shape passed to ``postprocessor``; its last two entries set the spatial
            dimensions of the combined residuals.
        indices : np.ndarray, None
            Passed unchanged to ``postprocessor``.

        Returns
        -------
        NoneType
            None

        Raises
        ------
        ValueError
            If IFS data is processed without a ``WAVELENGTH`` attribute, or if the
            processing type is not 'ADI', 'SDI', 'SDI+ADI', or 'ADI+SDI'.
        """

        start_time = time.time()

        # Get the parallactic angles (sign flipped and offset by the extra rotation)
        parang = -1. * self.m_star_in_port.get_attribute('PARANG') + self.m_extra_rot

        if self.m_ifs_data:
            # Get the wavelengths
            if 'WAVELENGTH' in self.m_star_in_port.get_all_non_static_attributes():
                wavelength = self.m_star_in_port.get_attribute('WAVELENGTH')

            else:
                raise ValueError('The wavelengths are not found. These should be stored '
                                 'as the \'WAVELENGTH\' attribute.')

            # Calculate the wavelength ratios
            scales = scaling_factors(wavelength)

        else:
            scales = None

        # Two-stage processing stores its results on a (first, second) PC-number grid
        two_stage = self.m_processing_type in ('SDI+ADI', 'ADI+SDI')

        if self.m_processing_type in ('ADI', 'SDI'):
            pca_first = self.m_components
            pca_secon = [-1]  # Not used

        elif two_stage:
            pca_first = self.m_components[0]
            pca_secon = self.m_components[1]

        else:
            # Fail early instead of with an UnboundLocalError further down
            raise ValueError(f'Processing type \'{self.m_processing_type}\' is not supported.')

        # Setup output arrays
        out_array_res = np.zeros(im_shape)

        if self.m_ifs_data:
            if two_stage:
                res_shape = (len(pca_first), len(pca_secon), len(wavelength),
                             im_shape[-2], im_shape[-1])

            else:
                res_shape = (len(pca_first), len(wavelength), im_shape[-2], im_shape[-1])

        else:
            res_shape = (len(pca_first), im_shape[-2], im_shape[-1])

        out_array_mean = np.zeros(res_shape)
        out_array_medi = np.zeros(res_shape)
        out_array_weig = np.zeros(res_shape)
        out_array_clip = np.zeros(res_shape)

        # (combination method, output port, output array), in the order in which the
        # residuals are combined and written
        combinations = (('mean', self.m_res_mean_out_port, out_array_mean),
                        ('median', self.m_res_median_out_port, out_array_medi),
                        ('weighted', self.m_res_weighted_out_port, out_array_weig),
                        ('clipped', self.m_res_rot_mean_clip_out_port, out_array_clip))

        n_steps = len(pca_first) * len(pca_secon)

        # Loop over all combinations of PC numbers and apply the reductions
        for i, pca_1 in enumerate(pca_first):
            for j, pca_2 in enumerate(pca_secon):
                # Flattened position in the nested loop, so the progress is monotonic
                progress(i * len(pca_secon) + j, n_steps, 'Creating residuals...', start_time)

                # Process images
                residuals, res_rot = postprocessor(images=star_reshape,
                                                   angles=parang,
                                                   scales=scales,
                                                   pca_number=(pca_1, pca_2),
                                                   pca_sklearn=self.m_pca,
                                                   im_shape=im_shape,
                                                   indices=indices,
                                                   processing_type=self.m_processing_type)

                # 1.) derotated residuals
                if self.m_res_arr_out_ports is not None:
                    if not self.m_ifs_data:
                        self.m_res_arr_out_ports[pca_1].set_all(res_rot)
                        self.m_res_arr_out_ports[pca_1].copy_attributes(self.m_star_in_port)
                        self.m_res_arr_out_ports[pca_1].add_history(
                            'PcaPsfSubtractionModule', f'max PC number = {pca_first}')

                    else:
                        # NOTE(review): overwritten on every iteration, so only the residuals
                        # of the last (pca_1, pca_2) combination are written out below
                        out_array_res = residuals

                # 2.) - 5.) mean, median, noise-weighted, and clipped mean residuals
                index = (i, j) if two_stage else i

                for method, port, out_array in combinations:
                    if port is None:
                        continue

                    # Only the noise-weighted combination also receives the residuals
                    extra = {'residuals': residuals} if method == 'weighted' else {}

                    out_array[index] = combine_residuals(method=method,
                                                         res_rot=res_rot,
                                                         angles=parang,
                                                         **extra)

        # Configure data output according to the processing type
        # 1.) derotated residuals
        if self.m_res_arr_out_ports is not None and self.m_ifs_data:
            if pca_secon[0] == -1:
                history = f'max PC number = {pca_first}'

            else:
                history = f'max PC number = {pca_first} / {pca_secon}'

            # Squeeze out_array_res to reduce the dimensionality, as the residuals of
            # SDI+ADI and ADI+SDI are always of the form (1, 1, ...)
            squeezed = np.squeeze(out_array_res)

            if isinstance(self.m_components, tuple):
                self.m_res_arr_out_ports.set_all(squeezed, data_dim=squeezed.ndim)
                self.m_res_arr_out_ports.copy_attributes(self.m_star_in_port)
                self.m_res_arr_out_ports.add_history('PcaPsfSubtractionModule', history)

            else:
                # NOTE(review): squeezed holds only the last iteration's residuals, so
                # squeezed[i] indexes its first axis rather than a per-component result;
                # confirm this is intended when several components are requested
                for i, pca in enumerate(self.m_components):
                    self.m_res_arr_out_ports[pca].append(squeezed[i])
                    self.m_res_arr_out_ports[pca].add_history('PcaPsfSubtractionModule',
                                                              history)

        # 2.) - 5.) combined residuals
        for _, port, out_array in combinations:
            if port is not None:
                port.set_all(out_array, data_dim=out_array.ndim)
Ejemplo n.º 2
0
    def _run_multi_processing(self, star_reshape: np.ndarray, im_shape: tuple,
                              indices: Optional[np.ndarray]) -> None:
        """
        Internal function to create the residuals, derotate the images, and write the output
        using multiprocessing.

        The combined-residual output ports that are not ``None`` are first initialized with
        zero-filled arrays of the final shape. All ports are then closed, and the work is
        handed to a ``PcaMultiprocessingCapsule``. The capsule receives deep copies of the
        components, the PCA object, the images, the parallactic angles, and the wavelength
        scaling factors.

        Parameters
        ----------
        star_reshape : np.ndarray
            Science images, deep-copied and passed to the multiprocessing capsule.
        im_shape : tuple
            Image shape passed to the capsule; its last two entries set the spatial
            dimensions of the initialized output datasets.
        indices : np.ndarray, None
            Passed unchanged to the multiprocessing capsule.

        Returns
        -------
        NoneType
            None

        Raises
        ------
        ValueError
            If IFS data is processed without a ``WAVELENGTH`` attribute, or if the
            processing type is not 'ADI', 'SDI', 'SDI+ADI', or 'ADI+SDI'.
        """

        cpu = self._m_config_port.get_attribute('CPU')

        # Parallactic angles (sign flipped and offset by the extra rotation)
        parang = -1. * self.m_star_in_port.get_attribute('PARANG') + self.m_extra_rot

        if self.m_ifs_data:
            if 'WAVELENGTH' in self.m_star_in_port.get_all_non_static_attributes():
                wavelength = self.m_star_in_port.get_attribute('WAVELENGTH')

            else:
                raise ValueError('The wavelengths are not found. These should be stored '
                                 'as the \'WAVELENGTH\' attribute.')

            scales = scaling_factors(wavelength)

        else:
            scales = None

        # Two-stage processing stores its results on a (first, second) PC-number grid
        two_stage = self.m_processing_type in ('SDI+ADI', 'ADI+SDI')

        if self.m_processing_type in ('ADI', 'SDI'):
            pca_first = self.m_components
            pca_secon = [-1]  # Not used

        elif two_stage:
            pca_first = self.m_components[0]
            pca_secon = self.m_components[1]

        else:
            # Fail early instead of with an UnboundLocalError further down
            raise ValueError(f'Processing type \'{self.m_processing_type}\' is not supported.')

        # Same output shape as in the single-process path
        if self.m_ifs_data:
            if two_stage:
                res_shape = (len(pca_first), len(pca_secon), len(wavelength),
                             im_shape[-2], im_shape[-1])

            else:
                res_shape = (len(pca_first), len(wavelength), im_shape[-2], im_shape[-1])

        else:
            res_shape = (len(pca_first), im_shape[-2], im_shape[-1])

        # Combined-residual output ports that are in use, in a fixed order
        out_ports = [port for port in (self.m_res_mean_out_port,
                                       self.m_res_median_out_port,
                                       self.m_res_weighted_out_port,
                                       self.m_res_rot_mean_clip_out_port)
                     if port is not None]

        # Initialize the output datasets with zeros of the final shape
        tmp_output = np.zeros(res_shape)

        for port in out_ports:
            port.set_all(tmp_output, keep_attributes=False)

        # All ports are closed before the multiprocessing capsule is created and run
        self.m_star_in_port.close_port()
        self.m_reference_in_port.close_port()

        for port in out_ports:
            port.close_port()

        if self.m_basis_out_port is not None:
            self.m_basis_out_port.close_port()

        capsule = PcaMultiprocessingCapsule(
            self.m_res_mean_out_port, self.m_res_median_out_port,
            self.m_res_weighted_out_port, self.m_res_rot_mean_clip_out_port,
            cpu, deepcopy(self.m_components), deepcopy(self.m_pca),
            deepcopy(star_reshape), deepcopy(parang), deepcopy(scales),
            im_shape, indices, self.m_processing_type)

        capsule.run()