Пример #1
0
    def __init__(self,
                 scidata,
                 t_ord,
                 p_ord,
                 lam_ord,
                 lam_grid=None,
                 lam_bounds=None,
                 sig=None,
                 mask=None,
                 thresh=1e-5):
        """Store the inputs needed to extract one spectral order.

        Parameters
        ----------
        scidata : array
            Science image. A copy is kept in ``self.data`` with the pixels
            selected by the global mask set to 0.
        t_ord : callable or array
            Throughput. A callable is evaluated on ``self.lam_grid``;
            anything else is copied as-is (it should have the same length
            as the wavelength grid).
        p_ord : array
            PSF / spatial profile of the order; a copy is kept in
            ``self.p_ord``.
        lam_ord : array
            Wavelength of each pixel; a copy is kept in ``self.lam_ord``.
        lam_grid : array, optional
            Wavelength grid. If None, it is built from ``lam_ord`` and
            ``p_ord`` with ``grid_from_map(..., out_col=True)``, which also
            provides ``self.lam_col``; otherwise ``self.lam_col`` is
            ``slice(None)``.
        lam_bounds : optional
            Accepted but not used by this initializer.
        sig : array, optional
            Pixel uncertainties. Defaults to ones shaped like ``scidata``.
        mask : array, optional
            Passed to ``self._get_mask`` to build the global mask.
        thresh : float, optional
            Stored in ``self.thresh``.
        """
        # Use `lam_grid` at the center of the trace if not specified
        if lam_grid is None:
            lam_grid, lam_col = grid_from_map(lam_ord, p_ord, out_col=True)
        else:
            lam_col = slice(None)

        # Save wavelength grid
        self.lam_grid = lam_grid.copy()
        self.lam_col = lam_col

        # Compute delta lambda for the grid.
        # NOTE(review): assumes get_lam_p_or_m returns the two bin-edge arrays
        # stacked along axis 0 — TODO confirm.
        self.d_lam = -np.diff(get_lam_p_or_m(lam_grid), axis=0)[0]

        # Basic parameters to save
        self.N_k = len(lam_grid)
        self.thresh = thresh

        # Uncertainties default to 1 for every pixel
        self.sig = np.ones_like(scidata) if sig is None else sig.copy()

        # Save PSF and pixel wavelengths
        self.p_ord = p_ord.copy()
        self.lam_ord = lam_ord.copy()

        # Throughput: a callable is projected on the grid, an array is copied.
        # Checking `callable` (rather than calling and catching TypeError)
        # prevents a TypeError raised *inside* the throughput function from
        # being misinterpreted as "this is an array".
        if callable(t_ord):
            self.t_ord = t_ord(self.lam_grid)
        else:
            self.t_ord = t_ord.copy()

        # Build global mask
        self.mask = self._get_mask(mask)

        # Keep a copy of the data with masked pixels zeroed.
        # TODO: try setting to np.nan instead?
        self.data = scidata.copy()
        self.data[self.mask] = 0
def test_ker_params(width_list, fwhm_list, n_os):
    """Plot the relative convolution error for several kernel widths and FWHMs.

    For each ``width`` in ``width_list`` and each ``fwhm`` in ``fwhm_list``,
    the reference WebbKer kernel (oversampling 10, width 21 pixels) is cut
    with ``cut_ker_box`` and written to a local dummy file that ``WebbKer``
    is redirected to. A reference convolved spectrum is computed on a grid
    oversampled 10 times, then compared with the convolved spectrum computed
    on a grid oversampled ``n_os`` times. One subplot is drawn per width,
    with one curve per FWHM.

    Parameters
    ----------
    width_list : sequence
        Box widths passed to ``cut_ker_box``; one subplot each.
    fwhm_list : sequence
        FWHM values passed to ``cut_ker_box``; one curve per subplot each.
        Must not be empty.
    n_os : int
        Oversampling of the wavelength grid being tested.

    Raises
    ------
    ValueError
        If ``fwhm_list`` is empty.

    Notes
    -----
    Side effects: writes ``spectral_kernel_matrix_test.fits`` in the current
    directory, sets the ``WebbKer.file_frame`` and ``WebbKer.path`` class
    attributes (not restored afterwards) and shows a matplotlib figure.
    """
    if len(fwhm_list) == 0:
        # Each subplot marks ``ker_list[0].wv_center``, which only exists once
        # a kernel has been built for at least one FWHM.
        raise ValueError("fwhm_list must contain at least one FWHM value.")

    # List of orders to consider in the extraction
    order_list = [1]

    path = "../jwst-mtl/SOSS/extract/Ref_files/"

    # Convert data from fits files to float (fits precision is 1e-8)
    #### Wavelength solution ####
    wave_maps = [fits.getdata(path + "wavelengths_m1.fits").astype('float64')]

    #### Spatial profiles ####
    spat_pros = [
        fits.getdata(path + "spat_profile_m1.fits").squeeze().astype('float64')
    ]

    # No tilt: every row of the wavelength map is replaced by row 50
    wave_maps = [
        np.tile(wv_map[50], (wv_map.shape[0], 1)) for wv_map in wave_maps
    ]

    #### Throughputs ####
    # NOTE: despite its name, this throughput is 1 / x, not a constant.
    def fct_ones(x):
        return 1 / x

    thrpt_list = [fct_ones for _ in order_list]

    # Injected test spectrum (linear in wavelength)
    def flux_fct(wv):
        return 1e5 - 1e4 * wv

    ### Original kernel ###
    # Read data and header once; `fits.getdata` does not leave the file open
    # (the previous `fits.open` handle was never closed).
    webbker_file = (path + "spectral_kernel_matrix/"
                    + "spectral_kernel_matrix_os_10_width_21pixels.fits")
    webbker_data, webbker_header = fits.getdata(webbker_file, header=True)

    ### Set WebbKer class parameters to local test ###
    # NOTE: these class attributes are not restored when the function returns.
    dummy_file = "spectral_kernel_matrix_test.fits"
    WebbKer.file_frame = dummy_file
    WebbKer.path = ''

    # The reference grid (oversampled 10x), the grid under test (oversampled
    # n_os times) and the kernel length only depend on `n_os`: build once.
    lam_ref = grid_from_map(wave_maps[0][50:51],
                            spat_pros[0][50:51],
                            n_os=10,
                            wv_range=[0.8, 3.0])
    lam_test = grid_from_map(wave_maps[0][50:51],
                             spat_pros[0][50:51],
                             n_os=n_os,
                             wv_range=[0.8, 3.0])
    # 21 pixels at oversampling n_os, bumped by one if even to keep it odd
    length_ker = 21 * n_os + ((21 * n_os) % 2 == 0)

    ### Init figure ###
    # squeeze=False always returns a 2D array of Axes, even for a single
    # subplot, so no special case is needed for len(width_list) == 1.
    n_rows = len(width_list)
    fig, ax = plt.subplots(n_rows,
                           1,
                           sharex=True,
                           figsize=(12, 3 * n_rows),
                           squeeze=False)
    ax = ax[:, 0]

    for i_ax, width in enumerate(width_list):
        print(f"width {i_ax+1}/{len(ax)}")
        for fwhm in fwhm_list:

            # Generate new kernel file read by WebbKer
            kernels = cut_ker_box(webbker_data[0],
                                  width=width,
                                  n_os=10,
                                  fwhm=fwhm)
            hdu = fits.PrimaryHDU(np.array([kernels, webbker_data[1]]),
                                  header=webbker_header)
            hdu.writeto(dummy_file, overwrite=True)

            #### Convolution kernels ####
            ker_list = [WebbKer(wave_maps[0])]

            # Put all inputs from reference files in a list
            ref_files_args = [spat_pros, wave_maps, thrpt_list, ker_list]

            # Reference ("theoretical") convolved spectrum on the 10x grid
            simu = TrpzOverlap(*ref_files_args,
                               lam_grid=lam_ref,
                               thresh=1e-8,
                               orders=[1],
                               c_kwargs={
                                   'n_out': [5 * 10, 8 * 10],
                                   'length': 21 * 10 + 1
                               })

            f_c_th = simu.c_list[0].dot(flux_fct(simu.lam_grid))
            fct_f_c_th = interp1d(simu.lam_grid_c(0),
                                  f_c_th,
                                  bounds_error=False,
                                  fill_value=np.nan,
                                  kind='cubic')

            # Convolved spectrum on the grid under test
            simu = TrpzOverlap(*ref_files_args,
                               lam_grid=lam_test,
                               thresh=1e-8,
                               orders=[1],
                               lam_bounds=[[0.88, 2.8]],
                               c_kwargs={
                                   'thresh_out': 1e-12,
                                   'length': length_ker
                               })

            f_c = simu.c_list[0].dot(flux_fct(lam_test))
            wv_c = simu.lam_grid_c(0)

            # Relative error with respect to the reference convolution
            ax[i_ax].plot(wv_c,
                          (f_c - fct_f_c_th(wv_c)) / f_c,
                          label=fwhm)

        ax[i_ax].set_title(f"Oversampling: {n_os}, box width: {width}")
        # Mark `wv_center` of the kernel without changing the y range
        y_lim = ax[i_ax].get_ylim()
        ax[i_ax].vlines(ker_list[0].wv_center,
                        *y_lim,
                        alpha=0.2,
                        linestyle="--")
        ax[i_ax].set_ylim(*y_lim)
        ax[i_ax].set_ylabel("convolution rel error (f_c - f_c_th)")
        ax[i_ax].legend(title="FWHM")

    ax[-1].set_xlabel("Wavelength [um]")
    plt.tight_layout()
    plt.show()
Пример #3
0
def run_tikho_tests(p_list,
                    lam_list,
                    scidata,
                    f_th_c,
                    n_os_list,
                    c_thresh_list,
                    t_mat_n_os_list,
                    factors=None,
                    file_root=None,
                    file_ext=None,
                    path=None):
    """Scan Tikhonov factors over a grid of extraction settings and save them.

    For every grid oversampling in ``n_os_list``, every convolution-kernel
    threshold in ``c_thresh_list`` and every Tikhonov-matrix oversampling in
    ``t_mat_n_os_list``, ``extra.get_tikho_tests(factors, t_mat=...)`` is run
    on a ``TrpzOverlap`` object. Its results, the injected spectra evaluated
    on the convolved grids (keys ``'f_k_th_1'`` and ``'f_k_th_2'``) and the
    wavelength grid (key ``'grid'``) are written with ``np.savez``.

    Parameters
    ----------
    p_list : sequence of 2 arrays
        Profiles of the two orders (presumably orders 1 and 2).
    lam_list : sequence of 2 arrays
        Wavelength maps of the two orders.
    scidata : array
        Image passed to ``TrpzOverlap``.
    f_th_c : sequence of 2 callables
        Injected spectra, evaluated on ``extra.lam_grid_c(0)`` and
        ``extra.lam_grid_c(1)`` respectively.
    n_os_list, c_thresh_list, t_mat_n_os_list : iterables
        Values scanned by the three nested loops.
    factors : array, optional
        Tikhonov factors; defaults to ``10 ** -arange(10, 25, 0.3)``.
    file_root, file_ext, path : str, optional
        Output name is ``path + file_root + file_ext.format(n_os, c_thresh,
        t_mat_n_os)``; defaults are ``'tikho_test'``,
        ``'.n_os_{}.c_thresh_{:1.0e}.tikho_os_{}'`` and ``''``.
    """
    # Split the per-order inputs
    P1, P2 = p_list
    wv_1, wv_2 = lam_list

    # Fill in defaults
    factors = 10.**(-1 * np.arange(10, 25, 0.3)) if factors is None else factors
    file_root = 'tikho_test' if file_root is None else file_root
    file_ext = ('.n_os_{}.c_thresh_{:1.0e}.tikho_os_{}'
                if file_ext is None else file_ext)
    path = '' if path is None else path

    status_template = 'n_os={}, c_thresh={:1.0e}, t_mat_n_os={}'

    for n_os in n_os_list:
        # Wavelength grid for this oversampling
        grid = get_soss_grid([P1, P2], [wv_1, wv_2], n_os=n_os)

        for c_thresh in c_thresh_list:
            extractor = TrpzOverlap([P1, P2], [wv_1, wv_2],
                                    scidata=scidata,
                                    lam_grid=grid,
                                    thresh=1e-5,
                                    c_kwargs={'thresh': c_thresh})

            # Injected spectra projected on each order's convolved grid
            injected = {}
            injected['f_k_th_1'] = f_th_c[0](extractor.lam_grid_c(0))
            injected['f_k_th_2'] = f_th_c[1](extractor.lam_grid_c(1))

            # Span of the extraction grid (reused for every Tikhonov matrix)
            lam_lo = extractor.lam_grid.min()
            lam_hi = extractor.lam_grid.max()
            span = [lam_lo, lam_hi]

            for t_mat_n_os in t_mat_n_os_list:
                print(status_template.format(n_os, c_thresh, t_mat_n_os))

                # Fake wavelength map covering `span` at `t_mat_n_os` times
                # the resolution of order 2.
                coarse = grid_from_map(wv_2, P2, wv_range=span)
                fine = oversample_grid(coarse, n_os=t_mat_n_os)

                # Tikhonov matrix: order-2 convolution minus identity
                c_ord2 = get_c_matrix(WebbKer(fine[None, :]),
                                      extractor.lam_grid,
                                      thresh=1e-5)
                tikho_mat = c_ord2 - identity(c_ord2.shape[0])

                results = extractor.get_tikho_tests(factors, t_mat=tikho_mat)

                out_name = (path + file_root
                            + file_ext.format(n_os, c_thresh, t_mat_n_os))
                np.savez(out_name,
                         **{**results, **injected, 'grid': extractor.lam_grid})
Пример #4
0
# Throughputs: the same `fct_ones` callable, once per extracted order.
thrpt_list = [fct_ones for _ in order_list]

#### Convolution kernels ####
# One delta kernel [0, 0, 1, 0, 0] per wavelength map (a separate array for
# each map); presumably this makes the convolution a no-op.
ker_list = [np.array([0, 0, 1, 0, 0]) for _ in wave_maps]

# Reference-file inputs, in the positional order used by TrpzOverlap below
ref_files_args = [spat_pros, wave_maps, thrpt_list, ker_list]


def flux_fct(wv):
    """Return the linear test spectrum ``1e5 * wv``."""
    slope = 1e5
    return slope * wv


# Grid: built from the first wavelength map and spatial profile,
# oversampled 15 times
lam_simu = grid_from_map(wave_maps[0], spat_pros[0], n_os=15)

# Init simu (thresh=1e-8, restricted to orders=[1])
simu = TrpzOverlap(*ref_files_args, lam_grid=lam_simu, thresh=1e-8, orders=[1])

# Simulated image from `flux_fct`
# NOTE(review): `orders=[0]` presumably indexes the first entry of the
# simulation's order list (order 1) — TODO confirm.
scidata = simu.rebuild(flux_fct, orders=[0])

# 4 x 2 = 8 panels, one per entry of n_os_list (flattened below so they can
# be indexed by i_os)
fig, ax = plt.subplots(4, 2, figsize=(6, 10))

n_os_list = [1, 2, 3, 5, 8, 11, 14, 15]

ax = ax.ravel()
for i_os, oversample in enumerate(n_os_list):
    #
    ### Simulation ###
    #
    # NOTE(review): the body of this loop appears to be missing from this
    # snippet: only comments follow, and the next statement is back at
    # column 0, so as written this `for` has no statements and will not
    # parse. TODO restore the loop body.
# Load a high-resolution PHOENIX model spectrum and its wavelength table.
# The files are opened with `with` so they are closed once read (previously
# both handles were left open). `hdu`, `file` and `path` stay bound
# afterwards, as before, for any later code that refers to them.
path = "/Users/antoinedb/Models/PHOENIX_HiRes/"
file = "Z-0.0/lte06000-4.50-0.0.PHOENIX-ACES-AGSS-COND-2011-HiRes.fits"
with fits.open(path + file) as hdu:
    flux = hdu[0].data

file = "WAVE_PHOENIX-ACES-AGSS-COND-2011.fits"
with fits.open(path + file) as hdu:
    wv = hdu[0].data / 10000.  # Angstrom to microns

# Keep only relevant wv range (0.5 < wv < 3.0 microns)
i_good = (0.5 < wv) & (wv < 3.0)
wv, flux = wv[i_good], flux[i_good]

# First convolution and resampling to reduce the length.
# Build estimate of the convolution kernel on a 0.5-3.0 um grid oversampled
# 20 times, from the `wv_2` / `P2` maps (presumably order 2, per the names).
wv_grid = grid_from_map(wv_2, P2, wv_range=[0.5, 3.0], n_os=20)
# Build convolution matrix only on a 500-sample window of the model grid
# (samples 500000-500499) to keep it small.
conv_ord2 = get_c_matrix(WebbKer(wv_grid[None, :]),
                         wv[500000:500500],
                         thresh=1e-5)

# Take the same convolution kernel everywhere
# (approximation, but we are still at high oversampling).
# Row 250 is the middle of the 500-sample window; columns 180:320 keep 140
# samples around it (assuming row 250 is centred on column 250 — confirm).
kernel = conv_ord2[250, 180:320].toarray().squeeze()
# NOTE(review): np.convolve flips `kernel`, whereas a matrix row is applied as
# a plain dot product; the two agree only if the kernel is symmetric. With an
# even-length kernel, mode='same' may also shift the output by one sample.
# Confirm whether np.correlate (or kernel[::-1]) was intended.
flux = np.convolve(kernel, flux, mode='same')
flux_fct = interp1d(wv, flux, kind='cubic', bounds_error=False)
# Resample on `wv_grid` oversampled 4 times
wv = oversample_grid(wv_grid, n_os=4)
flux = flux_fct(wv)

# Build accurate convolution matrix