Exemplo n.º 1
0
def test_trace_polynomial():
    """Check the lengths returned by ``lt.trace_polynomial`` for several argument combinations"""
    # Without an explicit order
    all_orders = lt.trace_polynomial(order=None, evaluate=False)
    assert len(all_orders) == 2

    # With a single order requested
    first_order = lt.trace_polynomial(order=1, evaluate=False)
    assert len(first_order) == 5

    # With evaluation switched on
    evaluated = lt.trace_polynomial(order=1, evaluate=True)
    assert len(evaluated) == 2048
Exemplo n.º 2
0
    def plot(self, scale='linear', coeffs=None, **kwargs):
        """
        Plot the frames of data

        Parameters
        ----------
        scale: str
            The scale to plot, ['linear', 'log']
        coeffs: sequence (optional)
            The polynomial coefficients of the traces; when omitted the
            default ``lt.trace_polynomial()`` coefficients are used
        **kwargs
            Passed through to ``plt.plot_frames``

        Returns
        -------
        fig
            The figure returned by ``plt.plot_frames``
        """
        # Collapse the data into a stack of 2D frames
        dim = self.data.shape
        if self.data.ndim == 4:
            data = self.data.reshape(dim[0] * dim[1], dim[2], dim[3])
        elif self.data.ndim == 2:
            data = self.data.reshape(1, dim[0], dim[1])
        else:
            data = self.data

        # Respect caller-supplied coefficients; only fall back to the default
        # traces when none were given (previously they were always overwritten)
        if coeffs is None:
            coeffs = lt.trace_polynomial()

        # Make the figure
        title = '{} Frames'.format(self.ext)
        fig = plt.plot_frames(data,
                              scale=scale,
                              trace_coeffs=coeffs,
                              wavecal=self.wavecal,
                              title=title,
                              **kwargs)

        return fig
Exemplo n.º 3
0
    def subarray(self, subarr):
        """Set the subarray and recompute every quantity that depends on it

        Parameters
        ----------
        subarr: str
            The name of the subarray to use,
            ['SUBSTRIP256', 'SUBSTRIP96', 'FULL']

        Raises
        ------
        ValueError
            If ``subarr`` is not one of the supported names
        """
        supported = ['SUBSTRIP256', 'SUBSTRIP96', 'FULL']

        # Guard: reject unknown subarray names before touching any state
        if subarr not in supported:
            message = "'{}' not a supported subarray. Try {}".format(subarr, supported)
            raise ValueError(message)

        # Store the name and the specs describing it
        self._subarray = subarr
        self.subarray_specs = utils.subarray_specs(subarr)

        # Detector geometry, wavelength solution and trace coefficients
        self._ncols = 2048
        self._nrows = self.subarray_specs.get('y')
        self.wave = utils.wave_solutions(subarr)
        self.avg_wave = np.mean(self.wave, axis=1)
        self.coeffs = locate_trace.trace_polynomial(subarray=subarr)

        # Data, time arrays and psfs built for the old subarray are now stale
        for reset_name in ('_reset_data', '_reset_time', '_reset_psfs'):
            getattr(self, reset_name)()
Exemplo n.º 4
0
    def plot(self,
             idx=0,
             scale='linear',
             order=None,
             noise=True,
             traces=False,
             saturation=0.8,
             draw=True):
        """
        Plot a TSO frame

        Parameters
        ----------
        idx: int
            The frame index to plot
        scale: str
            Plot scale, ['linear', 'log']
        order: int (optional)
            The order to isolate, [1, 2]; any other value uses the full
            ``tso`` (or ``tso_ideal``) cube
        noise: bool
            Plot with the noise model (only consulted when ``order`` is not
            1 or 2, since isolated orders always use the ideal data)
        traces: bool
            Plot the traces used to generate the frame
        saturation: float
            The fraction of full well defined as saturation
        draw: bool
            Render the figure instead of returning it

        Returns
        -------
        fig or None
            The figure from ``plotting.plot_frames`` when ``draw`` is False,
            otherwise None
        """
        # Get the data cube
        if order in [1, 2]:
            tso = getattr(self, 'tso_order{}_ideal'.format(order))
        elif noise:
            tso = self.tso
        else:
            tso = self.tso_ideal

        # Reshape into a stack of frames without mutating the stored cube;
        # assigning to ``tso.shape`` would reshape the instance attribute itself
        tso = tso.reshape(self.dims3)

        # Set the plot args
        wavecal = self.wave
        title = '{} - Frame {}'.format(self.title, idx)
        coeffs = locate_trace.trace_polynomial() if traces else None

        # Plot the frame
        fig = plotting.plot_frames(data=tso,
                                   idx=idx,
                                   scale=scale,
                                   trace_coeffs=coeffs,
                                   saturation=saturation,
                                   title=title,
                                   wavecal=wavecal)

        if draw:
            show(fig)
        else:
            return fig
Exemplo n.º 5
0
def _pool_map(func, items, mprocessing, star=False):
    """
    Apply ``func`` to each element of ``items``, optionally in parallel

    Parameters
    ----------
    func: callable
        The function to apply
    items: iterable
        The inputs; with ``star=True`` each element is unpacked into ``func``
    mprocessing: bool
        Use an 8-process ``multiprocessing.Pool`` instead of a serial loop
    star: bool
        Unpack each element as positional arguments, like ``Pool.starmap``

    Returns
    -------
    list
        The results in input order
    """
    if mprocessing:
        # The context manager cleans up the worker processes even if a task raises
        with multiprocessing.Pool(8) as pool:
            return pool.starmap(func, items) if star else pool.map(func, items)

    if star:
        return [func(*args) for args in items]
    return [func(arg) for arg in items]


def SOSS_psf_cube(filt='CLEAR', order=1, subarray='SUBSTRIP256', generate=False, mprocessing=True, wave_sol=None, dirname='default'):
    """
    Generate or retrieve the SOSS PSFs placed on the subarray

    With ``generate=True`` a 76x76 pixel psf is computed at 2048 wavelengths
    for each trace order (only order 1 for the F277W filter), scaled to unity,
    rotated to reproduce the trace tilt at each wavelength, placed on the
    desired subarray and saved to disk as four .npy chunks per order.
    Otherwise the four chunks for ``order`` are loaded and concatenated.

    Parameters
    ----------
    filt: str
        The filter to use, ['CLEAR', 'F277W']
    order: int
        The trace order to retrieve (not used when ``generate=True``, which
        builds every order)
    subarray: str
        The subarray to use, ['SUBSTRIP96', 'SUBSTRIP256', 'FULL']
    generate: bool
        Generate a new cube
    mprocessing: bool
        Use multiprocessing
    wave_sol: np.ndarray (optional)
        The user provided wavelength solutions for orders 1, 2, and 3,
        of shape (3, 2048)
    dirname: str (optional)
        The subdirectory of PSF_DIR for PSFs that use a custom wavelength
        solution; created if it does not exist

    Returns
    -------
    np.ndarray or None
        The concatenated PSF chunks for ``order``, or an array of ones of
        shape (2048, nrows, 76) when PSF_DIR is None; None when
        ``generate=True`` since the PSFs are only written to disk

    Raises
    ------
    TypeError
        If ``wave_sol`` is not of shape (3, 2048)
    """
    # Custom wavelength solutions live in their own subdirectory of PSF_DIR;
    # read and write from exactly the directory that is created
    psf_loc = PSF_DIR
    if dirname != 'default' and PSF_DIR is not None:
        psf_loc = os.path.join(PSF_DIR, dirname)
        os.makedirs(psf_loc, exist_ok=True)

    if not generate:

        if PSF_DIR is None:
            print("No PSF files detected. Using all ones.")
            nrows = 256 if subarray == 'SUBSTRIP256' else 96 if subarray == 'SUBSTRIP96' else 2048
            return np.ones((2048, nrows, 76))

        print("Using SOSS PSF files located at {}".format(psf_loc))

        # Get the chunked data and concatenate
        full_data = [np.load(os.path.join(psf_loc, 'SOSS_{}_PSF_order{}_{}.npy'.format(filt, order, chunk)))
                     for chunk in [1, 2, 3, 4]]
        return np.concatenate(full_data, axis=0)

    print('This takes about 2 minutes.')

    # Default wavelengths, or the validated user-provided ones
    if wave_sol is None:
        wavelengths = np.mean(utils.wave_solutions(subarray), axis=1)[:3 if filt == 'CLEAR' else 1]
    else:
        if wave_sol.shape != (3, 2048):
            raise TypeError("'wave_sol' input must be an array of shape (3, 2048)")
        wavelengths = wave_sol

    # Load the SOSS psf cube and its wavelength axis
    psf_file = os.path.join(PSF_DIR, 'SOSS_{}_PSF.fits'.format(filt))
    cube = fits.getdata(psf_file).swapaxes(-1, -2)
    wave = fits.getdata(psf_file, ext=1)

    # Interpolator over wavelength
    psfs = interp1d(wave, cube, axis=0, kind=3)
    trace_cols = np.arange(2048)

    for n, wavelength in enumerate(wavelengths):
        trace_order = n + 1

        # Don't calculate order 2 or order 3 for F277W
        if n > 0 and filt.lower() == 'f277w':
            continue

        # y-position of the trace center in each column for THIS order
        coeffs = locate_trace.trace_polynomial(subarray=subarray, order=trace_order)
        trace_centers = np.polyval(coeffs, trace_cols)

        # Get the psf for each column
        print('Calculating order {} SOSS psfs for {} filter...'.format(trace_order, filt))
        start = time.time()
        func = partial(get_SOSS_psf, filt=filt, psfs=psfs)
        raw_psfs = np.array(_pool_map(func, wavelength, mprocessing))
        print('Finished in {} seconds.'.format(time.time()-start))

        # Rotate the psfs by the tilt of the order being built (not the
        # ``order`` argument, which only selects what to load)
        print('Rotating order {} SOSS psfs for {} filter...'.format(trace_order, filt))
        start = time.time()
        func = partial(rotate, reshape=False)
        angles = psf_tilts(trace_order)
        rotated_psfs = np.array(_pool_map(func, zip(raw_psfs, angles), mprocessing, star=True))
        print('Finished in {} seconds.'.format(time.time()-start))

        # Scale psfs to 1
        rotated_psfs = np.abs(rotated_psfs)
        scale = np.nansum(rotated_psfs, axis=(1, 2))[:, None, None]
        rotated_psfs = rotated_psfs / scale

        # Split it into 4 chunks to be below Github file size limit
        chunks = rotated_psfs.reshape(4, 512, 76, 76)
        for N, chunk in enumerate(chunks):
            centers = trace_centers[N * 512:(N + 1) * 512]

            # Interpolate the psfs onto the subarray
            print('Interpolating chunk {}/4 for order {} SOSS psfs for {} filter onto subarray...'.format(N + 1, trace_order, filt))
            start = time.time()
            subarray_psfs = _pool_map(put_psf_on_subarray, zip(chunk, centers), mprocessing, star=True)
            print('Finished in {} seconds.'.format(time.time()-start))

            # Replace any existing file (no shell command needed)
            psf_path = os.path.join(psf_loc, 'SOSS_{}_PSF_order{}_{}.npy'.format(filt, trace_order, N + 1))
            if os.path.isfile(psf_path):
                os.remove(psf_path)

            np.save(psf_path, np.array(subarray_psfs))
            print('Data saved to', psf_path)

    return None
Exemplo n.º 6
0
def calculate_psf_tilts():
    """
    Calculate the tilt of the psf at the center of each column
    using all binned pixels in the given wavelength calibration file
    for orders 1, 2 and 3 and save each order to file

    For every column, the wavelength at the trace center defines a bin that
    extends halfway to the neighbouring columns' wavelengths (the first and
    last columns use 0 and 10 as their missing neighbours). The angle between
    the column's trace center and the last pixel of the wavelength map inside
    that bin (row-major order, as returned by ``np.where``) is taken modulo
    180, or 0 when no pixel falls in the bin. The angles are saved to
    ``SOSS_PSF_tilt_order{order}.npy`` in PSF_DIR.
    """
    # Dimensions (the same for every order)
    subarray = 'SUBSTRIP256'
    X = range(2048)
    Y = range(256)

    for order in [1, 2, 3]:

        # Get the file
        psf_file = os.path.join(PSF_DIR, 'SOSS_PSF_tilt_order{}.npy'.format(order))

        # Get the wave map
        wave_map = utils.wave_solutions(subarray, order).astype(float)

        # Get the y-coordinate of the trace polynomial in this column
        # (center of the trace)
        coeffs = locate_trace.trace_polynomial(subarray=subarray, order=order)
        trace = np.polyval(coeffs, X)

        # Interpolate to get the wavelength value at the center
        # NOTE(review): scipy.interpolate.interp2d was removed in SciPy 1.14;
        # this needs porting (e.g. RegularGridInterpolator) on newer SciPy
        wave = interp2d(X, Y, wave_map)

        # Get the wavelength of the trace center in each column
        trace_wave = [wave(x, y)[0] for x, y in zip(X, trace)]
        last_col = len(trace_wave) - 1

        # For each column wavelength (defined by the wavelength at
        # the trace center) define an isowavelength contour
        angles = []
        for x in X:

            w = trace_wave[x]

            # Edge cases: explicit bounds checks, because trace_wave[x - 1]
            # at x == 0 would silently wrap to the last column instead of
            # raising IndexError
            w0 = trace_wave[x - 1] if x > 0 else 0
            w1 = trace_wave[x + 1] if x < last_col else 10

            # Define the width of the wavelength bin as half-way
            # between neighboring points
            dw0 = np.mean([w0, w])
            dw1 = np.mean([w1, w])

            # Get the coordinates of all the pixels in that range
            yy, xx = np.where(np.logical_and(wave_map >= dw0, wave_map < dw1))

            # Find the angle between the vertical and the tilted wavelength bin
            if len(xx) >= 1:
                angle = get_angle([xx[-1], yy[-1]], [x, trace[x]])
            else:
                angle = 0

            # Don't flip them upside down
            angles.append(angle % 180)

        # Save the file
        np.save(psf_file, np.array(angles))
        print('Angles saved to', psf_file)
Exemplo n.º 7
0
 def test_coeffs(self):
     """Plot frames with trace coefficients and check a bokeh Column comes back"""
     trace_coeffs = lt.trace_polynomial()
     figure = plt.plot_frames(self.frames, scale='log', trace_coeffs=trace_coeffs)
     expected_type = "<class 'bokeh.models.layouts.Column'>"
     self.assertEqual(str(type(figure)), expected_type)