示例#1
0
 def update_S(self, new_S, update_spectral_radius=True):
     """Replace the current eigenPSFs and refresh their degraded versions.

     Parameters
     ----------
     new_S : np.ndarray
         New eigenPSFs. They are converted with ``utils.reg_format`` before
         use, so presumably given in "rca" format.
     update_spectral_radius : bool
         If True (default), recompute the spectral radius with
         ``PowerMethod.get_spec_rad`` once ``self.FdS`` has been rebuilt.
     """
     self.S = new_S
     # Per-star renormalization: flux relative to the median flux, over noise.
     flux_weights = self.flux / (np.median(self.flux) * self.sig)
     kernels = utils.reg_format(self.ker)
     degraded = []
     for component in utils.reg_format(self.S):
         # Degrade this eigenPSF once per star with that star's shift kernel.
         degraded.append([
             weight * degradation_op(component, kernel, self.D)
             for weight, kernel in zip(flux_weights, kernels)
         ])
     self.FdS = np.array(degraded)
     if update_spectral_radius:
         PowerMethod.get_spec_rad(self)
示例#2
0
    def MtX(self, x):
        """Adjoint to degradation operator :func:`MX`.

        Parameters
        ----------
        x : np.ndarray
            Stack of images, one per star. It is converted with
            ``utils.reg_format``, so it is presumably in "rca" format like the
            output of :func:`MX`.

        Returns
        -------
        np.ndarray
            The adjoint-degraded images, recombined with ``self.A.T`` and then
            passed through ``utils.apply_transform`` with ``self.filters``.
        """
        # Same per-star renormalization factors as in MX.
        normfacs = self.flux / (np.median(self.flux) * self.sig)
        kernels_rot = utils.reg_format(self.ker_rot)
        upsamp_x = np.array([
            nf * adjoint_degradation_op(x_i, shift_ker, self.D)
            for nf, x_i, shift_ker in zip(normfacs, utils.reg_format(x),
                                          kernels_rot)
        ])
        # Only the upsampled stack is needed back in "rca" format; the old
        # code also converted `x` back, but that result was never used.
        upsamp_x = utils.rca_format(upsamp_x)
        return utils.apply_transform(upsamp_x.dot(self.A.T), self.filters)
示例#3
0
    def MX(self, transf_S):
        """Apply degradation operator and renormalize.

        Parameters
        ----------
        transf_S : np.ndarray
            Current eigenPSFs in Starlet space.

        Returns
        -------
        np.ndarray
            Degraded, renormalized reconstruction of every star in "rca"
            format. The same array is also stored in ``self._current_rec``.

        """
        scale = self.flux / (np.median(self.flux) * self.sig)
        # Back from the filter domain to pixel space, one eigenPSF at a time.
        eigen_psfs = utils.rca_format(
            np.array([
                filter_convolve(coeffs, self.filters, filter_rot=True)
                for coeffs in transf_S
            ]))
        degraded = []
        for star_scale, star_weights, kernel in zip(
                scale, self.A.T, utils.reg_format(self.ker)):
            # Recombine eigenPSFs with this star's weights, then degrade.
            psf = eigen_psfs.dot(star_weights)
            degraded.append(star_scale * degradation_op(psf, kernel, self.D))
        self._current_rec = utils.rca_format(np.array(degraded))
        return self._current_rec
示例#4
0
 def validation_stars(self, test_stars, test_pos):
     """ Match PSF model to stars - in flux, shift and pixel sampling - for validation tests.

     Shifts are estimated from each star's centroid and fluxes with
     ``utils.flux_estimate_stack``. The PSF model is then evaluated at
     ``test_pos`` and degraded accordingly through :func:`estimate_psf`.

     Parameters
     ----------
     test_stars: np.ndarray
         Star stamps to be used for comparison with the PSF model. Should be in "rca" format,
         i.e. with axes (n_pixels, n_pixels, n_stars).
     test_pos: np.ndarray
         Their corresponding positions.

     Returns
     -------
     np.ndarray
         The matched PSF stamps, as returned by :func:`estimate_psf` with
         ``apply_degradation=True`` (regular format, one stamp per star along
         the first axis).

     Raises
     ------
     ValueError
         If the RCA instance has not been fitted yet.
     """
     if not self.is_fitted:
         raise ValueError(
             'RCA instance has not yet been fitted to observations. '
             'Please run the fit method.')
     # Intra-pixel shifts from each star's centroid.
     cents = [
         utils.CentroidEstimator(star, sig=self.psf_size)
         for star in utils.reg_format(test_stars)
     ]
     test_shifts = np.array([ce.return_shifts() for ce in cents])
     test_fluxes = utils.flux_estimate_stack(test_stars, rad=4)
     return self.estimate_psf(test_pos,
                              apply_degradation=True,
                              shifts=test_shifts,
                              flux=test_fluxes)
示例#5
0
    def MtX(self, x):
        """Adjoint to degradation operator :func:`MX`.

        Parameters
        ----------
        x : np.ndarray
            Set of finer-grid images (converted with ``utils.reg_format``).

        Returns
        -------
        np.ndarray
            Pixel-wise inner products between every degraded component in
            ``self.FdS`` and the images, mapped through ``self.VT.T``.
        """
        images = utils.reg_format(x)
        projections = []
        for degraded_comp in self.FdS:
            # Sum over both pixel axes: one inner product per image.
            projections.append((degraded_comp * images).sum(axis=(1, 2)))
        # self.VT.T is "V".
        return np.array(projections).dot(self.VT.T)
示例#6
0
 def _initialize(self) -> None:
     """ Initialization tasks related to noise levels, shifts and flux.

     This includes renormalizing the observed data, so it needs to be run even
     if noise levels, shifts and fluxes are all provided.

     Sets ``self.sigs`` (estimated when None, otherwise copied),
     ``self.sig_min``, ``self.shifts`` (estimated when None),
     ``self.shift_ker_stack`` / ``self.shift_ker_stack_adj``, ``self.flux``
     (estimated when None) and ``self.flux_ref`` (median flux).

     It then divides ``self.sigs`` by ``self.sig_min`` and divides
     ``self.obs_data`` in place by those relative per-star noise levels.
     """
     # Filters used here for noise estimation: coarse filters from
     # get_mr_filters on the observation shape, or the user-given Phi_filters.
     if self.default_filters:
         init_filters = get_mr_filters(self.shap[:2],
                                       opt=self.opt,
                                       coarse=True)
     else:
         init_filters = self.Phi_filters
     # noise levels
     if self.sigs is None:
         transf_data = utils.apply_transform(self.obs_data, init_filters)
         transf_mask = utils.transform_mask(self.obs_weights,
                                            init_filters[0])
         # Robust per-star noise estimate from the first filter's coefficients.
         # NOTE(review): presumably utils.mad is the (weighted) median absolute
         # deviation, and 1.4826 the usual MAD-to-Gaussian-sigma factor.
         sigmads = np.array([
             1.4826 * utils.mad(fs[0], w)
             for fs, w in zip(transf_data, utils.reg_format(transf_mask))
         ])
         # Bring the estimate back to pixel scale by the first filter's norm.
         self.sigs = sigmads / np.linalg.norm(init_filters[0])
     else:
         # Copy so the in-place normalization below leaves the caller's array alone.
         self.sigs = np.copy(self.sigs)
     # Taken before normalization, so the shift thresholds below use absolute sigmas.
     self.sig_min = np.min(self.sigs)
     # intra-pixel shifts
     if self.shifts is None:
         # Threshold a copy so centroiding does not alter obs_data itself.
         thresh_data = np.copy(self.obs_data)
         cents = []
         for i in range(self.shap[2]):
             # don't allow thresholding to be over 80% of maximum observed pixel
             nsig_shifts = min(
                 self.ksig_init,
                 0.8 * self.obs_data[:, :, i].max() / self.sigs[i])
             thresh_data[:, :, i] = utils.HardThresholding(
                 thresh_data[:, :, i], nsig_shifts * self.sigs[i])
             cents += [
                 utils.CentroidEstimator(thresh_data[:, :, i],
                                         sig=self.psf_size)
             ]
         self.shifts = np.array([ce.return_shifts() for ce in cents])
     # Shift kernels (and their adjoints) are always rebuilt from self.shifts.
     self.shift_ker_stack, self.shift_ker_stack_adj = utils.shift_ker_stack(
         self.shifts, self.upfact)
     # flux levels (estimated on the data before the noise normalization below)
     if self.flux is None:
         #TODO: could actually pass on the centroids to flux estimator since we have them at this point
         self.flux = utils.flux_estimate_stack(self.obs_data, rad=4)
     self.flux_ref = np.median(self.flux)
     # Normalize noise levels observed data: sigs become relative to the
     # smallest one, and each star (last axis) is divided by its own level.
     self.sigs /= self.sig_min
     self.obs_data /= self.sigs.reshape(1, 1, -1)
示例#7
0
    def _fit(self):
        """Alternately update the eigenPSFs and their weights.

        Each of the ``self.nb_iter`` outer iterations does the following:

        1. Update the eigenPSFs (``comp``) with Condat's primal-dual
           algorithm on their transform-domain coefficients. It uses a
           Starlet-domain sparsity prox and a positivity constraint on the
           linear recombination. When ``self.nb_reweight`` is set, it runs
           that many rounds of reweighting, and each round feeds the
           reweighted thresholds back into the sparsity prox.
        2. Except on the last iteration, update the weights (``alpha``) with
           forward-backward and k-thresholding. Then renormalize weights and
           components to break their scale invariance.

        Results are stored in ``self.A``, ``self.S`` and ``self.alpha``. The
        matching model of the observations is stored in ``self.current_rec``.
        """
        weights = self.A
        comp = self.S
        alpha = self.alpha

        #### Source updates set-up ####
        # initialize dual variable and compute Starlet filters for Condat source updates
        dual_var = np.zeros(self.im_hr_shape)
        if self.default_filters:
            self.Phi_filters = get_mr_filters(self.im_hr_shape[:2],
                                              opt=self.opt,
                                              coarse=True)
        rho_phi = np.sqrt(
            np.sum(np.sum(np.abs(self.Phi_filters), axis=(1, 2))**2))

        # Set up source updates, starting with the gradient
        source_grad = grads.SourceGrad(self.obs_data, self.obs_weights,
                                       weights, self.flux, self.sigs,
                                       self.shift_ker_stack,
                                       self.shift_ker_stack_adj, self.upfact,
                                       self.Phi_filters)

        # sparsity in Starlet domain prox (this is actually assuming synthesis form);
        # the actual thresholds are set inside the loop
        sparsity_prox = rca_prox.StarletThreshold(0)

        # and the linear recombination for the positivity constraint
        lin_recombine = rca_prox.LinRecombine(weights, self.Phi_filters)

        #### Weight updates set-up ####
        # gradient
        weight_grad = grads.CoeffGrad(self.obs_data, self.obs_weights, comp,
                                      self.VT, self.flux, self.sigs,
                                      self.shift_ker_stack,
                                      self.shift_ker_stack_adj, self.upfact)

        # cost function
        weight_cost = costObj([weight_grad], verbose=self.modopt_verb)
        source_cost = costObj([source_grad], verbose=self.modopt_verb)

        # k-thresholding for spatial constraint
        def iter_func(x):
            return np.floor(np.sqrt(x)) + 1

        coeff_prox = rca_prox.KThreshold(iter_func)

        # One Condat run without reweighting, otherwise one per reweighting round.
        n_source_rounds = self.nb_reweight if self.nb_reweight else 1

        for k in range(self.nb_iter):
            #### Eigenpsf update ####
            # update gradient instance and linear recombination with new weights...
            source_grad.update_A(weights)
            lin_recombine.update_A(weights)

            # ... set optimization parameters...
            beta = source_grad.spec_rad + rho_phi
            tau = 1. / beta
            sigma = 1. / lin_recombine.norm * beta / 2

            # ... update sparsity prox thresholds...
            thresh = utils.reg_format(
                utils.acc_sig_maps(self.shap,
                                   self.shift_ker_stack_adj,
                                   self.sigs,
                                   self.flux,
                                   self.flux_ref,
                                   self.upfact,
                                   weights,
                                   sig_data=np.ones(
                                       (self.shap[2], )) * self.sig_min))
            thresholds = self.ksig * np.sqrt(
                np.array([
                    filter_convolve(Sigma_k**2, self.Phi_filters**2)
                    for Sigma_k in thresh
                ]))

            sparsity_prox.update_threshold(tau * thresholds)

            # ... and run source update(s):
            transf_comp = utils.apply_transform(comp, self.Phi_filters)
            reweighter = cwbReweight(thresholds) if self.nb_reweight else None
            for _ in range(n_source_rounds):
                source_optim = optimalg.Condat(transf_comp,
                                               dual_var,
                                               source_grad,
                                               sparsity_prox,
                                               Positivity(),
                                               linear=lin_recombine,
                                               cost=source_cost,
                                               max_iter=self.nb_subiter_S,
                                               tau=tau,
                                               sigma=sigma)
                transf_comp = source_optim.x_final
                if reweighter is not None:
                    reweighter.reweight(transf_comp)
                    # Feed the reweighted thresholds to the prox. Without this
                    # step every round reuses the initial thresholds and
                    # reweighting has no effect.
                    sparsity_prox.update_threshold(tau * reweighter.weights)

            comp = utils.rca_format(
                np.array([
                    filter_convolve(transf_compj, self.Phi_filters, True)
                    for transf_compj in transf_comp
                ]))

            #### Weight update ####
            if k < self.nb_iter - 1:
                # update sources and reset iteration counter for K-thresholding
                weight_grad.update_S(comp)
                coeff_prox.reset_iter()
                weight_optim = optimalg.ForwardBackward(
                    alpha,
                    weight_grad,
                    coeff_prox,
                    cost=weight_cost,
                    beta_param=weight_grad.inv_spec_rad,
                    auto_iterate=False)
                weight_optim.iterate(max_iter=self.nb_subiter_weights)
                alpha = weight_optim.x_final
                weights_k = alpha.dot(self.VT)

                # renormalize to break scale invariance
                weight_norms = np.sqrt(np.sum(weights_k**2, axis=1))
                comp *= weight_norms
                weights_k /= weight_norms.reshape(-1, 1)
                #TODO: replace line below with Fred's component selection
                # (to be extracted from `low_rank_global_src_est_comb`)
                ind_select = range(weights.shape[0])
                weights = weights_k[ind_select, :]
                #TODO: track component supports

        self.A = weights
        self.S = comp
        self.alpha = alpha
        # MX stores and returns the degraded model of the observations.
        self.current_rec = source_grad.MX(transf_comp)
        self.current_rec = source_grad._current_rec
示例#8
0
 def estimate_psf(self,
                  test_pos,
                  n_neighbors=15,
                  rbf_function='thin_plate',
                  apply_degradation=False,
                  shifts=None,
                  flux=None,
                  upfact=None,
                  rca_format=False):
     """ Estimate and return PSF at desired positions.

     The weights of every component are interpolated at each test position
     with an RBF fitted on the ``n_neighbors`` nearest observations. They are
     then turned into PSFs with ``self._transform``.

     Parameters
     ----------
     test_pos: array_like
         Positions where the PSF should be estimated, one row per position.
         Should be in the same format (units, etc.) as the ``obs_pos`` fed to
         :func:`RCA.fit`.
     n_neighbors: int
         Number of neighbors to use for RBF interpolation. Default is 15.
     rbf_function: str
         Type of RBF kernel to use. Default is ``'thin_plate'``.
     apply_degradation: bool
         Whether PSF model should be degraded (shifted and resampled on coarse grid),
         for instance for comparison with stars. If True, ``shifts`` must be provided.
         Default is False.
     shifts: np.ndarray
         Intra-pixel shifts to apply if ``apply_degradation`` is set to True.
     flux: array_like
         Flux levels by which the degraded PSFs are multiplied (relative to
         ``self.flux_ref``) when provided. Only used if ``apply_degradation``
         is set to True.
     upfact: int
         Upsampling factor; default is None, in which case that of the RCA instance will be used.
     rca_format: bool
         If True, returns the PSF model in "rca" format, i.e. with axes
         (n_pixels, n_pixels, n_stars). Otherwise, and by default, return them in
         "regular" format, (n_stars, n_pixels, n_pixels).

     Raises
     ------
     ValueError
         If the instance has not been fitted, or if ``apply_degradation`` is
         True but ``shifts`` is None.
     """
     if not self.is_fitted:
         raise ValueError(
             'RCA instance has not yet been fitted to observations. '
             'Please run the fit method.')
     # Fail fast, before the (expensive) interpolation.
     if apply_degradation and shifts is None:
         raise ValueError(
             'shifts must be provided when apply_degradation is True.')
     if upfact is None:
         upfact = self.upfact
     test_pos = np.asarray(test_pos)
     ntest = test_pos.shape[0]
     test_weights = np.empty((self.n_comp, ntest))
     for j, pos in enumerate(test_pos):
         # determine neighbors
         nbs, pos_nbs = utils.return_neighbors(pos, self.obs_pos, self.A.T,
                                               n_neighbors)
         # train RBF and interpolate for each component
         for i in range(self.n_comp):
             rbfi = Rbf(pos_nbs[:, 0],
                        pos_nbs[:, 1],
                        nbs[:, i],
                        function=rbf_function)
             test_weights[i, j] = rbfi(pos[0], pos[1])
     PSFs = self._transform(test_weights)
     if not apply_degradation:
         return PSFs if rca_format else utils.reg_format(PSFs)
     # NOTE(review): shift kernels are built with self.upfact, while the
     # degradation uses `upfact`. These only differ when a custom upfact is
     # passed; confirm which one is intended.
     shift_kernels, _ = utils.shift_ker_stack(shifts, self.upfact)
     deg_PSFs = np.array([
         grads.degradation_op(PSFs[:, :, j], shift_kernels[:, :, j],
                              upfact) for j in range(ntest)
     ])
     if flux is not None:
         deg_PSFs *= np.asarray(flux).reshape(-1, 1, 1) / self.flux_ref
     return utils.rca_format(deg_PSFs) if rca_format else deg_PSFs