Example #1
    def make_jpg(self):
        '''
        Converts a FITS file to JPG image
        '''

        # file to convert is lev0Dir/KOAID

        path = self.dirs['lev0']
        koaid = self.fitsHeader.get('KOAID')
        filePath = ''
        for root, dirs, files in os.walk(path):
            if koaid in files:
                filePath = os.path.join(root, koaid)
        if not filePath:
            self.log.error('make_jpg: Could not find KOAID: ' + koaid)
            return False
        self.log.info('make_jpg: converting {} to jpeg format'.format(filePath))

        #check if already exists? (JPG conversion is time consuming)
        #todo: Only allow skip if not fullRun? (ie Will we ever want to regenerate the jpg?)

        jpgFile = filePath.replace('.fits', '.jpg')
        if os.path.isfile(jpgFile):
            self.log.warning('make_jpg: file already exists. SKIPPING')
            return True

        # verify file exists

        try:
            if os.path.isfile(filePath):
                # image data to convert
                image = self.fitsHdu[0].data
                interval = ZScaleInterval()
                vmin, vmax = interval.get_limits(image)
                norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=AsinhStretch())
                plt.imshow(image, cmap='gray', origin='lower', norm=norm)
                plt.axis('off')
                # save as png, then convert to jpg
                pngFile = filePath.replace('.fits', '.png')
                jpgFile = pngFile.replace('.png', '.jpg')
                plt.savefig(pngFile)
                Image.open(pngFile).convert('RGB').save(jpgFile)
                os.remove(pngFile)
                plt.close()
            else:
                #TODO: if this errors, should we remove .fits file added previously?
                self.log.error('make_jpg: file does not exist {}'.format(filePath))
                return False
        except Exception:
            self.log.error('make_jpg: Could not create JPG: ' + jpgFile)
            return False

        return True
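
A minimal standalone sketch of the same FITS-to-JPG pipeline as above, assuming
astropy, matplotlib, and Pillow are installed (the function name is illustrative):

import os
from astropy.io import fits
from astropy.visualization import ZScaleInterval, AsinhStretch, ImageNormalize
import matplotlib.pyplot as plt
from PIL import Image

def fits_to_jpg_sketch(fits_path):
    # Normalize with ZScale limits plus an asinh stretch, as in make_jpg above.
    image = fits.getdata(fits_path)
    vmin, vmax = ZScaleInterval().get_limits(image)
    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=AsinhStretch())
    plt.imshow(image, cmap='gray', origin='lower', norm=norm)
    plt.axis('off')
    # Save as PNG first, then convert to JPG with Pillow.
    png_path = fits_path.replace('.fits', '.png')
    plt.savefig(png_path)
    Image.open(png_path).convert('RGB').save(png_path.replace('.png', '.jpg'))
    os.remove(png_path)
    plt.close()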
Example #2
    def __init__(self, data, header, **kwargs):
        if 'bunit' not in header and 'pixlunit' in header:
            # PIXLUNIT is not a FITS standard keyword
            header['bunit'] = header['pixlunit']

        super().__init__(data, header, **kwargs)

        # Fill in some missing info
        self.meta['detector'] = self.meta.get('detector', "AIA")
        self._nickname = self.detector
        self.plot_settings['cmap'] = self._get_cmap_name()
        self.plot_settings['norm'] = ImageNormalize(
            stretch=source_stretch(self.meta, AsinhStretch(0.01)), clip=False)
Example #3
def AsinhNorm(a=0.1):
    """Custom Arcsinh Norm.

    Parameters
    ----------
    a : float, optional
        Transition point from linear to logarithmic behavior, expressed as a
        fraction of the normalized image range (default 0.1); passed through
        to ``AsinhStretch``.

    Returns
    -------
    ImageNormalize

    """

    return ImageNormalize(stretch=AsinhStretch(a=a))
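
A short usage sketch for the helper above (the data array is illustrative):

import numpy as np
import matplotlib.pyplot as plt

data = np.random.rand(64, 64)
# imshow autoscales the ImageNormalize limits from the data before applying asinh.
plt.imshow(data, norm=AsinhNorm(a=0.1), origin='lower', cmap='gray')
plt.show()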
Example #4
def plot(ra, dec, gal_list, s_list, ser_list, catalogue, search_size, filename):

    size = 240*search_size
    
    fitsurl = geturl(ra, dec, size=size, filters=texas_cfg['filters'], format="fits")
    fh = fits.open(fitsurl[0])
    wcs = WCS(fh[0].header)
    fim = fh[0].data
    fim[np.isnan(fim)] = 0.
    transform = AsinhStretch() + PercentileInterval(99)
    bfim = transform(fim)
    
    fig = plt.figure(figsize=(search_size * 3, search_size * 3))
    ax = plt.subplot(projection=wcs)
    plt.imshow(bfim, cmap='summer')  # , norm=LogNorm())
    ax.set_xlim(0,240*search_size)
    ax.set_ylim(0,240*search_size)
    k = 1
    if len(gal_list)>0:
        last = gal_list[0]
    if catalogue == 'glade':
        for j in gal_list:
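            # Offsets in degrees are converted to pixels assuming 4 pixels/arcsec,
            # consistent with size = 240*search_size pixels per search_size arcmin.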
            if k==1 or ((abs(j['ra']-last['ra'])*np.cos(j['dec']/180*np.pi)>0.005 or abs(j['dec']-last['dec'])>0.005)):
                x, y = ((ra-j['ra'])*4*3600*np.cos(j['dec']/180*np.pi)+(size/2)), (j['dec']-dec)*4*3600+(size/2)
                ax.plot(x,y, 'bx', label = 'glade source')
#            print(j[10])
                z = int(float(j['z'])*10000)/10000.
                ax.annotate(str(k) + ': z='+str(z), xy=(x+20, y-5), fontsize=15, ha="center", color='b')
            k = k+1
            last = j
    elif catalogue == 'texas':
        plot_ellipse(gal_list, ra, dec, size, 'k', ax)#, label = 'texas source')

    for s in s_list:
        x, y = ((ra-s['raMean'])*4*3600*np.cos(s['decMean']/180*np.pi)+(size/2)), (s['decMean']-dec)*4*3600+(size/2)
        ax.plot(x,y, 'rx', label = 'PS point source')
	    
#    print(ser_list)
    if len(ser_list)>0:
        plot_ellipse(ser_list, ra, dec, size, 'm', ax)

    ax.plot([size/2+10, size/2+20], [size/2, size/2], 'k')
    ax.plot([size/2, size/2], [size/2+10, size/2+20], 'k')
    ax.plot([size/2-10, size/2-20], [size/2, size/2], 'k')
    ax.plot([size/2, size/2], [size/2-10, size/2-20], 'k')

    plt.savefig(filename)
Example #5
    def plot_image(self, plotfile):
        """Plot an asinh-stretched image with the current apertures to
           standard graphics bitmap-format file.
        """
        
        # Normally the range 0.5 percent to 99.5 percent clips off the
        # outliers.
        pct_interval = AsymmetricPercentileInterval(0.50, 99.5)

        #sqrt_norm    = ImageNormalize(self._data, 
        #    interval=pct_interval, 
        #    stretch=SqrtStretch())
        asinh_norm   = ImageNormalize(self._data,
            interval=pct_interval,
            stretch=AsinhStretch())
    
        fig, ax = plt.subplots()
        ax.tick_params(axis='both', labelsize=8)

        im = ax.imshow(self._data, 
            origin='lower', 
            norm=asinh_norm)
        self._apertures.plot(color='red', lw=1.5, alpha=0.5)
         
        # Clean up file name string to prevent _ becoming subscripts.
        #fname_str = self._fitsimg.replace('_', '\_') # Not necessary?
        fname_str = self._fitsimg
        
        # Subtitle
        num_stars = len(self._phot_table)
        info_str  = f'{num_stars} stars'
        if self._nosatmask is not None:
            info_str += ' excluding saturated stars'
        else:
            info_str += ' including saturated stars'
        if self._max_adu is not None:
            info_str = 'Brightest ' + info_str
            
        ax.set_title(f'{fname_str}\n{info_str}', fontsize=8)
        ax.set_xlabel('X-axis (pixels)', fontsize=8)
        ax.set_ylabel('Y-axis (pixels)', fontsize=8)
        #plt.show()
        # 'quality' and 'optimize' are PIL options, passed through pil_kwargs.
        plt.savefig(plotfile,
            dpi=200,
            bbox_inches='tight',
            pil_kwargs={'quality': 95, 'optimize': True})
        self._logger.info(f'Plotting asinh-stretched bitmap of image and sources to {plotfile}')
        return
Example #6
class TextureOutput(EventDispatcher):
    stretch_type = ConfigParserProperty(None,
                                        'view',
                                        'stretch',
                                        'app',
                                        val_type=StretchType)
    do_stretch = BooleanProperty(False)

    _stretch = {
        StretchType.LINEAR: LinearStretch(),
        StretchType.SQRT: SqrtStretch(),
        # StretchType.POWER: PowerStretch(),
        StretchType.ASIN: AsinhStretch()
    }

    def __init__(self):
        super(TextureOutput, self).__init__()
        self.register_event_type('on_update')
        self.texture = Texture.create(size=(320, 240))

    def write(self, buf):
        if self.do_stretch:
            buf = self.stretch(buf)
        self._blit(buf)

    @mainthread
    def _blit(self, buf):
        """
        write buffer to texture and fire event
        :param buf:
        :return:
        """
        self.texture.blit_buffer(buf, colorfmt='rgb', bufferfmt='ubyte')
        self.dispatch('on_update', buf)

    def on_update(self, buf):
        pass

    def stretch(self, buf):
        """
        perform the currently selected stretch method on buffer
        :param buf:
        :return:
        """
        im = np.frombuffer(buf, dtype=np.uint8)
        im = self._stretch[self.stretch_type](im / 255.0) * 255.0
        return im.astype(np.uint8).tobytes()
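
The buffer round-trip in stretch() can be illustrated standalone, assuming only
numpy and astropy (the buffer contents are illustrative):

import numpy as np
from astropy.visualization import AsinhStretch

buf = bytes(range(256)) * 3  # a fake 8-bit pixel buffer
im = np.frombuffer(buf, dtype=np.uint8)
# Stretches operate on values in [0, 1], so scale down, stretch, and scale back up.
out = (AsinhStretch()(im / 255.0) * 255.0).astype(np.uint8).tobytes()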
Example #7
def make_image_subplot(fig,
                       image,
                       wcs,
                       title=None,
                       vval=None,
                       use_norm=True,
                       cmap='Greys',
                       use_projection=True):
    ''' Plotting tool for checking astrometric solution of ACAM '''
    figkwargs = {}
    kwargs = {}
    if use_projection:
        figkwargs['projection'] = wcs
    else:
        extent = [
            -image.shape[1] / 2. * wcs.wcs.cdelt[0] * 3600.,
            image.shape[1] / 2. * wcs.wcs.cdelt[0] * 3600.,
            -image.shape[0] / 2. * wcs.wcs.cdelt[1] * 3600.,
            image.shape[0] / 2. * wcs.wcs.cdelt[1] * 3600.
        ]
        kwargs['extent'] = extent
    fig.add_subplot(111, **figkwargs)
    kwargs['origin'] = 'lower'
    kwargs['interpolation'] = 'none'
    kwargs['cmap'] = cmap
    if use_norm:
        kwargs['norm'] = ImageNormalize(stretch=AsinhStretch())
    if vval is not None:
        kwargs['vmin'] = vval[0]
        kwargs['vmax'] = vval[1]
    else:
        vl = np.percentile(image, 3)
        vh = np.percentile(image, 99)
        ran = vh - vl
        vl = vl - 0.2 * ran
        vh = vh + 1.5 * ran
        kwargs['vmin'] = vl
        kwargs['vmax'] = vh
    plt.imshow(image, **kwargs)
    if title is not None:
        plt.title(title)
    ax = plt.gca()
    plt.xlim(ax.get_xlim())
    plt.ylim(ax.get_ylim())
    plt.xlabel('RA')
    plt.ylabel('Dec')
Example #8
def plot_exposure(mask, wcs, image=None):
    plt.figure(figsize=(10, 10))
    ax = plt.subplot(projection=wcs)

    if image is not None:
        norm = ImageNormalize(image,
                              interval=ZScaleInterval(contrast=0.05),
                              stretch=AsinhStretch(a=0.2))
        ax.imshow(image, cmap='gray', norm=norm, interpolation='nearest')

    d = ax.imshow(mask, alpha=0.7)
    cb = plt.colorbar(d)
    cb.set_label('Live Time / Hours')
    ax.set_xlabel('Galactic Longitude')
    ax.set_ylabel('Galactic Latitude')

    return ax
Example #9
    def __init__(self, data, header, **kwargs):
        super().__init__(data, header, **kwargs)

        # Fill in some missing info
        self.meta['detector'] = self.meta.get('detector', "AIA")
        if 'bunit' not in self.meta and 'pixlunit' in self.meta:
            # PIXLUNIT is not a FITS standard keyword
            self.meta['bunit'] = self.meta['pixlunit']
        self._nickname = self.detector
        self.plot_settings['cmap'] = self._get_cmap_name()
        self.plot_settings['norm'] = ImageNormalize(
            stretch=source_stretch(self.meta, AsinhStretch(0.01)), clip=False)
        # DN is not a FITS standard unit, so convert to counts
        if self.meta.get('bunit', None) == 'DN':
            self.meta['bunit'] = 'ct'
        if self.meta.get('bunit', None) == 'DN/s':
            self.meta['bunit'] = 'ct/s'
Example #10
def stretch_data(data, method="HistEqStretch"):
    """
    Return an ImageNormalize for ``data`` using one of:
    LogStretch, SqrtStretch, AsinhStretch, HistEqStretch.
    """
    # Only HistEqStretch takes the image data itself; the other stretches take
    # scalar parameters (or none), so passing the array to their constructors
    # would raise an error.
    if method == "LogStretch":
        norm = ImageNormalize(stretch=LogStretch())
    elif method == "SqrtStretch":
        norm = ImageNormalize(stretch=SqrtStretch())
    elif method == "AsinhStretch":
        norm = ImageNormalize(stretch=AsinhStretch())
    elif method == "HistEqStretch":
        norm = ImageNormalize(stretch=HistEqStretch(data))
    else:
        norm = data
    return norm
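
A usage sketch for stretch_data (variable names are illustrative):

import matplotlib.pyplot as plt

norm = stretch_data(image_data, method="AsinhStretch")
plt.imshow(image_data, norm=norm, origin='lower', cmap='gray')
plt.show()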
Example #11
def fits_to_png(ff, outfile, log=False):
    plt.clf()
    ax = plt.axes()
    fim = ff[1].data
    # replace NaN values with zero for display
    fim[np.isnan(fim)] = 0.0
    # set contrast to something reasonable
    transform = AsinhStretch() + PercentileInterval(99.5)
    bfim = transform(fim)
    ax.imshow(bfim, cmap="gray", origin="lower")
    circle = plt.Circle((np.shape(fim)[1] / 2 - 1, np.shape(fim)[0] / 2 - 1),  # (x, y): shape is (rows, cols)
                        15,
                        color='r',
                        fill=False)
    ax.add_artist(circle)
    plt.gca().set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.savefig(outfile, bbox_inches='tight', pad_inches=0)
Example #12
    def __init__(self, preload=8, limit=65536):
        self.limit = limit
        self.tiles_per_raw_image = 256
        self.nfiles_considered = self.limit // self.tiles_per_raw_image
        self.preload = preload
        self.s3 = boto3.client('s3')
        self.zmax = ZMaxInterval()
        self.zscale = ZScaleInterval()
        self.logstretch = LogStretch()
        self.asinhstretch = AsinhStretch()

        filenames = []
        with open(
                os.path.join(os.path.dirname(__file__),
                             'astroquery-index.csv'), 'r') as index:
            while True:
                line = index.readline()
                if line == "":
                    break

                line = line.rstrip('\n')
                cols = line.split(',')
                s3uri = cols[1]
                s3key = s3uri.replace('s3://stpubdata/', '')
                filenames.append(s3key)

        print("{0} file URLs loaded".format(len(filenames)))
        print("Randomising and keeping {0} of them".format(
            self.nfiles_considered))
        random.shuffle(filenames)
        filenames = filenames[0:self.nfiles_considered]

        self.files = numpy.array(filenames)

        self.tile_buffer = []
        self.buffer_condition = threading.Condition()
        self.feeder = threading.Thread(target=self.feed_loop)

        self.done = False
        self.feeder.start()
Example #13
    def __init__(self, data, header, **kwargs):
        GenericMap.__init__(self, data, header, **kwargs)

        # Fill in some missing info
        self.meta['detector'] = "AIA"
        self._nickname = self.detector
        self.plot_settings['cmap'] = plt.get_cmap(self._get_cmap_name())
        self.plot_settings['norm'] = ImageNormalize(stretch=source_stretch(self.meta, AsinhStretch(0.01)))
Example #14
def asin_stretch_norm(images: Stamp):
    return ImageNormalize(
        images,
        interval=MinMaxInterval(),
        stretch=AsinhStretch(),
    )
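
A usage sketch, assuming the Stamp input is an array-like stack of cutouts
(names are illustrative):

import matplotlib.pyplot as plt

norm = asin_stretch_norm(stamps)
plt.imshow(stamps[0], norm=norm, origin='lower', cmap='gray')
plt.show()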
Example #15
    table = table[numpy.argsort(flist)]
    print(table)
    if color:
        if len(table) > 3:
            # pick 3 filters
            table = table[[0, len(table) // 2, len(table) - 1]]
        for i, param in enumerate(["red", "green", "blue"]):
            url = url + "&{}={}".format(param, table['filename'][i])
    else:
        urlbase = url + "&red="
        url = []
        for filename in table['filename']:
            url.append(urlbase + filename)
    return url

from astropy.io import fits
from astropy.visualization import PercentileInterval, AsinhStretch

fitsurl = geturl(ra, dec, size=size, filters="i", format="fits")
print(fitsurl, "URL")
fh = fits.open(fitsurl[0])
fim = fh[0].data
# replace NaN values with zero for display
fim[numpy.isnan(fim)] = 0.0
# set contrast to something reasonable
transform = AsinhStretch() + PercentileInterval(99.5)
bfim = transform(fim)

pylab.subplot(111)
pylab.title('Crab Nebula PS1 i (fits)')
pylab.imshow(bfim,cmap="gray",origin="lower")
Example #16
def extract_noise_and_exp_patches(candels_image, vela_noise_less_image,
                                  make_plots, is_image_cube, plots_loc,
                                  fits_save_loc):
    '''
    This is the main function: it creates a noised version of the input simulated image based on
    the real (HST, CANDELS) image. The user can choose whether or not to store diagnostic plots.
    :param candels_image: Input real postage stamp from which the noise properties will be extracted.
    Note: units should be in electrons/sec.
    :param vela_noise_less_image: An un-noised image of a simulated observation (note: it needs to be PSF convolved).
    The units of this image should be in micro-janskies per square arcsecond (G. Snyder et al.).
    :param make_plots: Set to True to make and store the diagnostic plots during the run, otherwise False.
    :param is_image_cube: Provide True if the real image is an image cube (as in my case); if it is a simple
    postage stamp of a galaxy, provide False.
    :param plots_loc: save location of the generated plots.
    :param fits_save_loc: save location of the generated noise-added fits files.
    :return: Nothing.
    '''
    if not is_image_cube:
        original_data = fits.getdata(
            candels_image)  # get the data from the real postage stamp.
    else:
        can_image_cube = fits.open(candels_image)  # open the image
        original_data = can_image_cube[
            1].data  # load the data from the second HDU, which has the original image.

    vela_noiseless_data = fits.getdata(
        vela_noise_less_image)  # get data of the noise-less simulated image
    vela_header = fits.getheader(
        vela_noise_less_image
    )  # load the header information from noise-less image.
    # Identify objects and generate the segmentation map of the sources in the real image.
    # the set of parameters i used are discussed in Mantha+19 (submitted), which i converged based on the
    # official numbers used in Guo et al., 2013 and Galametz et al., 2013. However, some values are modified
    # to aid the detection of small and faint sources in the images and make sure that the source segmentation
    # maps extend out to tails of the light distribution. Otherwise, the source light will bleed into the noise stamps.
    obj, segmap = identify_objects(original_data, 0.75, 7, 64, 0.001, {
        'sep_filter_kwarg': 'tophat',
        'sep_filter_size': '5'
    })

    sky_mask = deepcopy(
        segmap)  # make a deep copy of the segmap for further manipulation.
    sky_mask = sky_mask + 1  # add 1 to all seg values.
    sky_mask[
        sky_mask > 1] = 0  # make everything that is larger than 1 zero.
    # This step sets all source pixels to 0 and the 'sky' pixels to 1.

    noise_stamp = make_noise_patch(sky_mask * original_data, (700, 700))
    # Generate a large noise mosaic of size 700x700

    # based on the size of the simulated image, this step will query a noise stamp matching its size.
    noise_stamp_size_matched = image.extract_patches_2d(
        noise_stamp, (vela_header['NAXIS1'], vela_header['NAXIS2']),
        max_patches=1)

    if make_plots:  # If the make plots is True, then do the following steps...
        fig, axs = plt.subplots(
            2, 2, figsize=[20,
                           16])  # Create a four panel layout 2 rows x 2 cols..
        [[ax1, ax2], [ax3, ax4]] = axs  # Axes

        norm = ImageNormalize(original_data,
                              interval=PercentileInterval(98.),
                              stretch=AsinhStretch())
        # find the image normalization for visualization matching DS9 settings. 98 percentile, Asinh stretch.

        ax1.imshow(original_data,
                   origin='lower',
                   cmap='gray',
                   alpha=1,
                   norm=norm)  # visualize the real image
        for i in range(
                len(obj)
        ):  # this routine uses the source identified objects to put elliptical regions.
            e = Ellipse(xy=(obj['x'][i], obj['y'][i]),
                        width=6 * obj['a'][i],
                        height=6 * obj['b'][i],
                        angle=obj['theta'][i] * 180. / np.pi)
            e.set_facecolor('none')
            e.set_edgecolor('red')
            ax1.add_artist(e)
            ax1.text(obj['x'][i],
                     obj['y'][i],
                     '%s' % (get_segmap_value(segmap,
                                              (obj['x'][i], obj['y'][i]))),
                     fontsize=10,
                     color='black',
                     ha='center',
                     va='center')  # also put seg value at the center.

        # visualize the source masked image.
        ax2.imshow(sky_mask * original_data,
                   origin='lower',
                   interpolation='nearest',
                   norm=norm,
                   cmap='gray')

        # visualizing the histogram of pixel values in the source masked image.
        ax4.hist([i for i in (sky_mask * original_data).flat if i != 0],
                 bins=100,
                 density=True,  # normalize the histogram
                 color='red',
                 histtype='step')

        # visualizing the large noise mosaic
        ax3.imshow(noise_stamp_size_matched[0],
                   origin='lower',
                   interpolation='nearest',
                   norm=norm,
                   cmap='gray')

        # Sanity check, overplot another histogram of the large noise stamp.. to compare the two histograms.
        ax4.hist(noise_stamp.flat,
                 bins=100,
                 density=True,  # normalize the histogram
                 color='blue',
                 histtype='step')

        # draw the mean of the histogram.
        ax4.axvline(
            x=np.mean(noise_stamp.flat),
            linestyle='--',
            color='green',
            label='sky mean [e/s] = %s\n sky mean [muj/arcsec^2] = %s' %
            (round(np.mean(noise_stamp.flat), 4),
             round((1E-7 * np.mean(noise_stamp.flat) * (10**6)) /
                   (0.06 * 0.06), 4)))
        # ax4.imshow(noise_stamp,origin='lower', interpolation='nearest',norm=norm,cmap='gray')

        make_axis_labels_off(np.array(
            [ax1, ax2, ax3]))  # switch off the axes that show images.
        plt.subplots_adjust(
            wspace=0.01, hspace=0.01
        )  # Adjust the spacing between subplots to avoid whitespace.
        stamp_gen_image_name = os.path.basename(candels_image).replace(
            '.fits', '')  # figuring out the name for later use.
        ax4.legend(loc=2)  # show legend for axis4 (histogram).
        # save figure... CHANGE THE PATHS...
        plt.savefig('%s/%s.png' % (plots_loc, stamp_gen_image_name),
                    bbox_inches='tight')
        plt.close(fig)

    exp_time = 3300.0  # Assuming an exposure time of 3300 seconds based on the table in Koekemoer+11 for ~2-orbit depth.

    # Note that the real image from CANDELS are in units of e/s. It is important to convert into electrons.
    real_noise_stamp_counts = noise_stamp_size_matched[
        0] * exp_time  # converting to counts.

    pht_nu = float(
        vela_header['PHOTFNU']
    ) * 1E6  # PHOTFNU is in units of Jy * sec/electron; 1E6 converts to uJy.

    #Converting the noise-less vela image into units of electrons.
    # uJy/arcsec^2 * (arcsec^2)/ (uJy * seconds/electrons) * seconds = electrons
    noise_less_vela_electrons = (vela_noiseless_data *
                                 (0.06**2) / pht_nu) * exp_time
    if np.any(
            noise_less_vela_electrons < 0
    ):  # at some very early redshift snapshots.. there are some negative values.
        # to avoid the poisson routine from crashing, i made them zeros.
        noise_less_vela_electrons[np.where(
            noise_less_vela_electrons < 0)] = 0.0

    # Poisson realization of the noise-less simulated image.
    poisson_noised_vela = np.random.poisson(noise_less_vela_electrons,
                                            size=None)

    # Poisson residual. We want to make sure that the source Poisson noise is correlated spatially
    # as the background sky (owing to drizzling etc..)
    poisson_resid = poisson_noised_vela - noise_less_vela_electrons

    # Defining a smoothing kernel to smooth the poisson residual image such that the
    # 1D PSD of the resultant smoothed poisson residual matches the noise PSD.
    # Upon some extensive experimentation, i found that a Gaussian kernel of sigma=0.6
    # works the best (by eye). I feel that there might be better automated (sort of regression of MCMC) way
    # to figure this out on an image-by-image basis, but for now i hardcoded it.
    smth_kernel = Gaussian2DKernel(0.6)
    smoothed_poisson_resid = convolve(
        poisson_resid,
        kernel=smth_kernel)  # Smooth the poisson to induce correlation.

    correlated_poisson = smoothed_poisson_resid + noise_less_vela_electrons  # add this to noise-less image (still in counts)
    # Full noise-added image (correlated poisson realization + real noise stamp)
    full_noise_added = correlated_poisson + real_noise_stamp_counts

    # unit conversion back to uJy/arcsec^2
    full_noise_added_ujy_arcsec2 = (full_noise_added /
                                    exp_time) * pht_nu / (0.06**2)

    # Get the VELA ID -- this is a specific thing, whoever is using this code may change it as needed.
    what_vela_id = os.path.basename(vela_noise_less_image).replace(
        '_SB00.fits', '').split('_')[0]

    # writing additional information to the existing header. Preserving original header info by Snyder et al.,
    # with extra information for back tracking purposes.
    vela_header['expsr_t'] = '%s' % exp_time

    # creating a filename to store the output images.
    output_noise_filename = os.path.basename(vela_noise_less_image).replace(
        '_SB00.fits', '') + '_real_noise_stamp'
    vela_header[
        'REAL_N'] = '%s' % output_noise_filename  # the noise stamp name used in this run, printed to header
    fits.writeto(
        '%s/%s/%s.fits' % (fits_save_loc, what_vela_id, output_noise_filename),
        noise_stamp_size_matched[0] * pht_nu / (0.06**2),
        overwrite=True
    )  # saving the noise postage stamp in units of ujy/arcsec^2 (same as simulated images).

    output_poiss_filename = os.path.basename(vela_noise_less_image).replace(
        '_SB00.fits', '') + '_smoothed_poiss'
    vela_header['POISS'] = '%s' % output_poiss_filename
    fits.writeto(
        '%s/%s/%s.fits' % (fits_save_loc, what_vela_id, output_poiss_filename),
        (correlated_poisson / exp_time) * pht_nu / (0.06**2),
        overwrite=True
    )  # saving the correlated poisson realization in units of uJy/arcsec^2.

    output_noise_added_filename = os.path.basename(
        vela_noise_less_image).replace('_SB00.fits', '') + '_noise_added'
    fits.writeto(
        '%s/%s/%s.fits' %
        (fits_save_loc, what_vela_id, output_noise_added_filename),
        full_noise_added_ujy_arcsec2,
        overwrite=True,
        header=vela_header
    )  # saving the final noise-added image in units of uJy/arcsec^2
    if make_plots:  # if the make plots is set, then generate this large figure that showcases all steps.
        fig, axs = plt.subplots(2, 3, figsize=[18,
                                               8])  # 2 rows x 3 cols figure.
        [[ax1, ax2, ax3], [ax4, ax5, ax6]] = axs
        '''Normalizing all the image visualizations to the same setting...'''
        norm_nl = ImageNormalize(noise_less_vela_electrons,
                                 interval=PercentileInterval(98),
                                 stretch=AsinhStretch())
        norm_poiss = ImageNormalize(poisson_noised_vela,
                                    interval=PercentileInterval(98),
                                    stretch=AsinhStretch())
        norm_resid = ImageNormalize(poisson_resid,
                                    interval=PercentileInterval(98),
                                    stretch=AsinhStretch())

        norm_smth_resid = ImageNormalize(smoothed_poisson_resid,
                                         interval=PercentileInterval(98),
                                         stretch=AsinhStretch())
        norm_corr_poiss = ImageNormalize(correlated_poisson,
                                         interval=PercentileInterval(98),
                                         stretch=AsinhStretch())

        norm_full_noised = ImageNormalize(full_noise_added_ujy_arcsec2,
                                          interval=PercentileInterval(98.5),
                                          stretch=AsinhStretch())
        ''' Showing all the images in their respective axes...'''
        ax1.imshow(noise_less_vela_electrons,
                   origin='lower',
                   cmap='gray',
                   alpha=1,
                   norm=norm_nl)
        ax2.imshow(poisson_noised_vela,
                   origin='lower',
                   cmap='gray',
                   alpha=1,
                   norm=norm_poiss)
        ax3.imshow(poisson_resid,
                   origin='lower',
                   cmap='gray',
                   alpha=1,
                   norm=norm_resid)
        ax4.imshow(smoothed_poisson_resid,
                   origin='lower',
                   cmap='gray',
                   alpha=1,
                   norm=norm_smth_resid)
        ax5.imshow(correlated_poisson,
                   origin='lower',
                   cmap='gray',
                   alpha=1,
                   norm=norm_corr_poiss)
        ax6.imshow(full_noise_added_ujy_arcsec2,
                   origin='lower',
                   cmap='gray',
                   alpha=1,
                   norm=norm_full_noised)

        make_axis_labels_off(
            axs)  # switch off all axes labels, as all panels are images.
        plt.tight_layout(
        )  # tight layout... for effective (somewhat) professional visualization.
        plt.subplots_adjust(
            wspace=0.0,
            hspace=0.01)  # adjust the white space between rows and cols.

        fig_name = os.path.basename(vela_noise_less_image).replace(
            '_SB00.fits', '') + '_noise_addition'
        plt.savefig('%s/%s.png' % (plots_loc, fig_name),
                    bbox_inches='tight')  #save as png.
        plt.close(fig)  # close figure object to avoid clutter in memory.
    if make_plots:
        '''If make plots is set... then this block visualizes the 1D PSD of real noise stamp and
        the correlated poisson image.'''
        fig2 = plt.figure(figsize=[10, 8])
        ax3 = fig2.gca()
        psd1d_real_noise = do_autocorr_power_spectrum(
            real_noise_stamp_counts)  #1D PSD of real noise stamp
        psd1d_un_correlated_poiss = do_autocorr_power_spectrum(
            poisson_resid)  # 1D PSD of uncorrelated poisson (should be flat)
        psd1d_smoothed_poiss = do_autocorr_power_spectrum(
            smoothed_poisson_resid)  # 1D PSD of correlated poisson image.
        '''The following three lines show the 1D PSDs'''
        ax3.semilogy(psd1d_real_noise,
                     color='red',
                     linestyle='-',
                     label='Real noise')
        ax3.semilogy(psd1d_un_correlated_poiss,
                     color='red',
                     linestyle='--',
                     label='Poisson Residual')
        ax3.semilogy(psd1d_smoothed_poiss,
                     color='blue',
                     linestyle='-.',
                     label='Smoothed Poisson Residual')

        ax3.set_xlabel('Spatial Frequency', fontsize=18)  # xaxis labels
        ax3.set_ylabel('Normalized AutoCorr Power Spectrum',
                       fontsize=18)  # yaxis label.

        ax3.legend(loc=1, fontsize=16)  #legend..
        fig_name = os.path.basename(vela_noise_less_image).replace(
            '_SB00.fits', '') + '_autocorr'
        plt.savefig('%s/%s.png' % (plots_loc, fig_name),
                    bbox_inches='tight')  # save the plot as a png.
        plt.close(fig2)
    return
Example #17
#
# As with `~aiapy.psf.psf`, this will be much faster if you have
# a GPU and `cupy` installed.
#
#

# %%
m_deconvolved = aiapy.psf.deconvolve(m, psf=psf)

# %% [markdown]
# Let's compare the convolved and deconvolved images.
#
#

# %%
norm = ImageNormalize(vmin=0, vmax=1.5e4, stretch=AsinhStretch(0.01))
fig = plt.figure()
ax = fig.add_subplot(121, projection=m)
m.plot(axes=ax, norm=norm)
ax = fig.add_subplot(122, projection=m_deconvolved)
m_deconvolved.plot(axes=ax, annotate=False, norm=norm)
ax.coords[0].set_axislabel(' ')
ax.coords[1].set_axislabel(' ')
ax.coords[1].set_ticklabel_visible(False)
plt.show()

# %% [markdown]
# The differences become a bit more obvious when we zoom in. Note that the
# deconvolution has the effect of "deblurring" the image.
#
#
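
# %% [markdown]
# One way to do such a zoom-in, assuming `m_deconvolved` and `norm` from the
# cells above (the cutout coordinates below are illustrative, not from the
# original script):

# %%
from astropy.coordinates import SkyCoord
import astropy.units as u

corner = SkyCoord(Tx=500 * u.arcsec, Ty=-550 * u.arcsec,
                  frame=m_deconvolved.coordinate_frame)
m_zoom = m_deconvolved.submap(corner, width=250 * u.arcsec, height=250 * u.arcsec)
m_zoom.plot(norm=norm)
plt.show()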
Example #18
def run_barshadow_tests(plfile,
                        bsfile,
                        barshadow_threshold_diff=0.05,
                        save_final_figs=False,
                        show_final_figs=False,
                        save_intermediary_figs=False,
                        show_intermediary_figs=False,
                        write_barshadow_files=False,
                        debug=False):
    """

    Args:
        plfile: string, 2D spectra output prior to the bar shadow step (e.g., extract_2d or pathloss product)
        bsfile: string, read in 2D spectra output from the bar shadow step
        barshadow_threshold_diff: float, this value comes from the document ESA-JWST-SCI-NRS-TN-2016-016.pdf, it is
                                    an arbitrary error of the reference file of 0.0025 absolute error or 5% relative
                                    error (no justification provided)
        save_final_figs: boolean, if True the final figures with corresponding histograms will be saved
        show_final_figs: boolean, if True the final figures with corresponding histograms will be shown
        save_intermediary_figs: boolean, if True the intermediary figures with corresponding histograms will be saved
        show_intermediary_figs: boolean, if True the intermediary figures with corresponding histograms will be shown
        debug: boolean

    Returns:

    """

    # start the list of messages that will be added to the log file
    log_msgs = []

    # start the timer
    barshadow_test_start_time = time.time()

    # read in 2D spectra output prior to the bar shadow step (e.g., extract_2d or pathloss product)
    print(
        'Checking if files exist and obtaining datamodels, this takes a few minutes...'
    )
    if os.path.isfile(plfile):
        if debug:
            print('Extract_2d file does exist.')
    else:
        result_msg = 'Extract_2d file does NOT exist. Barshadow test will be skipped.'
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs

    # get the data model
    pl = datamodels.open(plfile)
    if debug:
        print('got extract_2d datamodel!')

    # read in 2D spectra output from the bar shadow step
    if os.path.isfile(bsfile):
        if debug:
            print('Bar shadow file does exist.')
    else:
        result_msg = 'Barshadow file does NOT exist. Barshadow test will be skipped.'
        log_msgs.append(result_msg)
        result = 'skip'
        return result, result_msg, log_msgs

    bs = datamodels.open(bsfile)
    if debug:
        print('got barshadow datamodel!')

    # list to determine if pytest is passed or not
    total_test_result = OrderedDict()

    if write_barshadow_files:
        # create the fits list to hold the image of the correction values
        hdu0 = fits.PrimaryHDU()
        outfile = fits.HDUList()
        outfile.append(hdu0)

        # create the fits list to hold the image of the comparison values
        hdu0 = fits.PrimaryHDU()
        complfile = fits.HDUList()
        complfile.append(hdu0)

    # loop over the slitlets in both files
    print('Looping over open slitlets...')
    for plslit, bsslit in zip(pl.slits, bs.slits):
        # check that slitlet name of the data from the pathloss or extract_2d and the barshadow datamodels are the same
        slit_id = bsslit.name
        print('Working with slitlet ', slit_id)
        if plslit.name == bsslit.name:
            msg = 'Slitlet name in fits file previous to barshadow and in barshadow output file are the same.'
            log_msgs.append(msg)
            print(msg)
        else:
            msg = '* Mismatch of slitlet names in fits file previous to barshadow and in barshadow output file. Skipping test.'
            result = 'skip'
            log_msgs.append(msg)
            return result, msg, log_msgs

        # obtain the data from the pathloss or extract_2d and the barshadow datamodels
        plsci = plslit.data
        bssci = bsslit.data

        if debug:
            print('plotting the data for both input files...')

        # set up generals for all the plots
        font = {  # 'family' : 'normal',
            'weight': 'normal',
            'size': 16
        }
        matplotlib.rc('font', **font)

        plt.figure(figsize=(12, 10))
        # Top figure
        plt.subplot(211)
        norm = ImageNormalize(plsci,
                              vmin=0.,
                              vmax=500.,
                              stretch=AsinhStretch())
        plt.imshow(plsci,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.title(
            'Normalized science data before barshadow step for slitlet ' +
            slit_id)
        # Bottom figure
        plt.subplot(212)
        norm = ImageNormalize(bssci,
                              vmin=0.,
                              vmax=500.,
                              stretch=AsinhStretch())
        plt.imshow(bssci,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.title('Normalized barshadow science data for slitlet ' + slit_id)
        # Show and/or save figures
        file_path = bsfile.replace(os.path.basename(bsfile), "")
        file_basename = os.path.basename(bsfile.replace("_barshadow.fits", ""))
        if save_intermediary_figs:
            t = (file_basename,
                 "Barshadowtest_NormSciData_slitlet" + slit_id + ".pdf")
            plt_name = "_".join(t)
            plt_name = os.path.join(file_path, plt_name)
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_intermediary_figs:
            plt.show()
        plt.close()

        # calculate spatial profiles for both products
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 9))
        plt.subplots_adjust(hspace=0.5)
        fig.subplots_adjust(wspace=0.6)
        point1 = [355, 375]
        plprof1 = np.median(plsci[:, point1[0]:point1[1]], 1)
        bsprof1 = np.median(bssci[:, point1[0]:point1[1]], 1)
        # only use pixels that are not NaN
        x1 = np.squeeze(np.nonzero(~np.isnan(bsprof1)))
        ax1.plot(x1, plprof1[x1])
        ax1.set_title('Before barshadow array slice 1')
        ax1.set_xlabel('x (pixels)')
        ax1.set_ylabel('y (pixels)')
        if debug:
            print('ax1 std_dev/mean = ',
                  np.nanstd(plprof1[x1]) / np.nanmean(plprof1[x1]))

        point2 = [1190, 1210]
        plprof2 = np.median(plsci[:, point2[0]:point2[1]], 1)
        bsprof2 = np.median(bssci[:, point2[0]:point2[1]], 1)
        x2 = np.squeeze(np.nonzero(~np.isnan(bsprof2)))
        ax2.plot(x2, plprof2[x2])
        ax2.set_title('Before barshadow array slice 2')
        ax2.set_xlabel('x (pixels)')
        ax2.set_ylabel('y (pixels)')
        if debug:
            print('ax2 std_dev/mean = ',
                  np.nanstd(plprof2[x2]) / np.nanmean(plprof2[x2]))

        ax3.plot(x1, bsprof1[x1])
        ax3.set_title('Barshadow array slice 1')
        ax3.set_xlabel('x (pixels)')
        ax3.set_ylabel('y (pixels)')
        if debug:
            print('ax3 std_dev/mean = ',
                  np.nanstd(bsprof1) / np.nanmean(bsprof1[x1]))

        ax4.plot(x2, bsprof2[x2])
        if debug:
            print('ax4 std_dev/mean = ',
                  np.nanstd(bsprof2) / np.nanmean(bsprof2[x2]))
        ax4.set_title('Barshadow array slice 2')
        ax4.set_xlabel('x (pixels)')
        ax4.set_ylabel('y (pixels)')

        fig.suptitle('Spatial profiles before correction for slitlet ' +
                     slit_id,
                     fontsize=20)

        # Show and/or save figures
        if save_intermediary_figs:
            t = (file_basename,
                 "Barshadowtest_SpatialProfilesBe4correction_slitlet" +
                 slit_id + ".pdf")
            plt_name = "_".join(t)
            plt_name = os.path.join(file_path, plt_name)
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_intermediary_figs:
            plt.show()
        plt.close()

        ### compare pipeline correction values with independent calculation

        # get the bar shadow corrections from the step product
        bscor_pipe = bsslit.barshadow

        # get correction from independent calculation
        msg = 'Calculating barshadow correction...'
        log_msgs.append(msg)
        print(msg)
        # Create x, y indices using the Trace WCS
        x, y = wcstools.grid_from_bounding_box(bsslit.meta.wcs.bounding_box,
                                               step=(1, 1))
        if debug:
            print('x = ', x)

        # derive the slity_y values per pixel
        wcsobj = bsslit.meta.wcs
        det2slit = wcsobj.get_transform('detector', 'slit_frame')
        bsslitx, bsslity, bswave = det2slit(x, y)
        # scale the slit_y values by 1.15 to take into account the shutter pitch
        bsslity = bsslity / 1.15

        # compute bar shadow corrections independently, given the wavelength and slit_y from the data model
        # get the reference file (need the mos1x1 for this internal lamp case, where each shutter was extracted separately)
        #if bsslit.shutter_state == 'x':
        ref_file = '/grp/jwst/wit4/nirspec/CDP3/05_Other_Calibrations/5.3_BarShadow/referenceFilesBS-20160401/jwst-nirspec-mos1x1.bsrf.fits'
        if bsslit.shutter_state == '1':
            ref_file = '/grp/jwst/wit4/nirspec/CDP3/05_Other_Calibrations/5.3_BarShadow/referenceFilesBS-20160401/jwst-nirspec-mos1x3.bsrf.fits'
        if debug:
            '''    shutter_state : str ----- ``Slit.shutter_state`` attribute - a combination of
                                    possible values: ``1`` - open shutter, ``0`` - closed shutter, ``x`` - main shutter
            '''
            print('slit.shutter_state = ', bsslit.shutter_state)
        msg = 'Reference file used for barshadow calculation: ' + ref_file
        log_msgs.append(msg)
        print(msg)
        hdul = fits.open(ref_file)
        bscor_ref = hdul[1].data
        w = wcs.WCS(hdul[1].header)
        y1, x1 = np.mgrid[:bscor_ref.shape[0], :bscor_ref.shape[1]]
        lam_ref, slity_ref = w.all_pix2world(x1, y1, 0)

        # for slit wcs, interpolate over the reference file values
        lam_ref = lam_ref.reshape(bscor_ref.size)
        slity_ref = slity_ref.reshape(bscor_ref.size)
        pixels_ref = np.column_stack((lam_ref, slity_ref))
        bscor_ref = bscor_ref.reshape(bscor_ref.size)
        bswave_ex = bswave.reshape(bswave.size)
        indxs = ~np.isnan(bswave_ex)
        bsslity_ex = bsslity.reshape(bsslity.size)
        xyints = np.column_stack((bswave_ex[indxs], bsslity_ex[indxs]))
        bscor = np.empty(bswave_ex.size)
        bscor[:] = np.nan
        bscor[indxs] = griddata(pixels_ref, bscor_ref, xyints, method='linear')
        bscor = bscor.reshape(bswave.shape[0], bswave.shape[1])
        if debug:
            print('bscor.shape = ', bscor.shape)
        msg = 'Calculation of barshadow correction done.'
        log_msgs.append(msg)
        print(msg)

        shutter_status = bsslit.shutter_state
        if bsslit.shutter_state == 'x':
            fi = shutter_status.find('x')
        if bsslit.shutter_state == '1':
            fi = shutter_status.find('1')
        if debug:
            print('fi = ', fi)
        nax2 = hdul[1].header['NAXIS2']
        cv1 = hdul[1].header['CRVAL1']
        cd1 = hdul[1].header['CDELT1']
        cd2 = hdul[1].header['CDELT2']
        shutter_height = 1. / cd2
        fi2 = nax2 - shutter_height * (1 + fi)
        if debug:
            print('nax2, fi2, shutter_height:', nax2, fi2, shutter_height)
        yrow = fi2 + bsslity * shutter_height
        wcol = (bswave - cv1) / cd1
        #print(yrow[9,1037],wcol[9,1037])
        if debug:
            print('np.shape(yrow)=', np.shape(yrow))
        point3 = [10, np.shape(yrow)[1] - 50]
        print(yrow[point3[0], point3[1]], wcol[point3[0], point3[1]])

        fig = plt.figure(figsize=(12, 10))
        # Top figure
        plt.subplot(211)
        plt.imshow(bscor,
                   vmin=0.,
                   vmax=1.,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.title('Calculated Correction')
        plt.colorbar()
        # Bottom figure
        plt.subplot(212)
        plt.imshow(bscor_pipe,
                   vmin=0.,
                   vmax=1.,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.title('Pipeline Correction')
        plt.colorbar()

        fig.suptitle('Barshadow correction comparison for slitlet ' + slit_id,
                     fontsize=20)

        # Show and/or save figures
        if save_intermediary_figs:
            t = (file_basename, "Barshadowtest_CorrectionComparison_slitlet" +
                 slit_id + ".pdf")
            plt_name = "_".join(t)
            plt_name = os.path.join(file_path, plt_name)
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_intermediary_figs:
            plt.show()
        plt.close()

        if debug:
            #print('bscor_pipe[9,1037],bswave[9,1037],bsslity[9,1037],bscor[9,1037]: ',
            #      bscor_pipe[9,1037],bswave[9,1037],bsslity[9,1037],bscor[9,1037])
            print(
                'bscor_pipe[point3[0], point3[1]],bswave[point3[0], point3[1]],bsslity[point3[0], point3[1]],bscor[point3[0], point3[1]]: ',
                bscor_pipe[point3[0], point3[1]], bswave[point3[0], point3[1]],
                bsslity[point3[0], point3[1]], bscor[point3[0], point3[1]])

        print('Creating final barshadow test plot...')
        reldiff = (bscor_pipe - bscor) / bscor
        if debug:
            print('np.nanmean(reldiff),np.nanstd(reldiff) : ',
                  np.nanmean(reldiff), np.nanstd(reldiff))
        fig = plt.figure(figsize=(12, 10))
        # Top figure - 2D plot
        plt.subplot(211)
        plt.imshow(reldiff,
                   vmin=-0.01,
                   vmax=0.01,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.colorbar()
        plt.title('Relative differences')
        plt.xlabel('x (pixels)')
        plt.ylabel('y (pixels)')
        # Bottom figure - histogram
        ax = plt.subplot(212)
        plt.hist(reldiff[~np.isnan(reldiff)], bins=100, range=(-0.1, 0.1))
        plt.xlabel(
            '(Pipeline_correction - Calculated_correction) / Calculated_correction'
        )
        plt.ylabel('N')
        # add vertical line at mean and median
        nanind = np.isnan(reldiff)  # get all the nan indexes
        notnan = ~nanind  # get all the not-nan indexes
        arr_mean = np.mean(reldiff[notnan])
        arr_median = np.median(reldiff[notnan])
        arr_stddev = np.std(reldiff[notnan])
        plt.axvline(arr_mean, label="mean = %0.3e" % (arr_mean), color="g")
        plt.axvline(arr_median,
                    label="median = %0.3e" % (arr_median),
                    linestyle="-.",
                    color="b")
        str_arr_stddev = "stddev = {:0.3e}".format(arr_stddev)
        ax.text(0.73,
                0.67,
                str_arr_stddev,
                transform=ax.transAxes,
                fontsize=16)
        plt.legend()
        plt.minorticks_on()

        fig.suptitle('Barshadow correction relative differences for slitlet ' +
                     slit_id,
                     fontsize=20)

        # Show and/or save figures
        if save_final_figs:
            t = (file_basename,
                 "Barshadowtest_RelDifferences_slitlet" + slit_id + ".pdf")
            plt_name = "_".join(t)
            plt_name = os.path.join(file_path, plt_name)
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_final_figs:
            plt.show()
        plt.close()

        # Determine if median test is passed
        slitlet_test_result_list = []
        tested_quantity = 'barshadow_correction'
        stats = auxfunc.print_stats(reldiff[notnan],
                                    tested_quantity,
                                    barshadow_threshold_diff,
                                    abs=False,
                                    return_percentages=True)
        _, stats_print_strings, percentages = stats
        result = auxfunc.does_median_pass_tes(arr_median,
                                              barshadow_threshold_diff)
        slitlet_test_result_list.append({tested_quantity: result})
        for line in stats_print_strings:
            log_msgs.append(line)
        msg = " * Result of median test for slit " + slit_id + ": " + result + "\n"
        print(msg)
        log_msgs.append(msg)

        tested_quantity = "percentage_greater_3threshold"
        result = auxfunc.does_median_pass_tes(percentages[1], 10)
        slitlet_test_result_list.append({tested_quantity: result})
        msg = " * Result of number of points greater than 3*threshold greater than 10%: " + result + "\n"
        print(msg)
        log_msgs.append(msg)

        tested_quantity = "percentage_greater_5threshold"
        result = auxfunc.does_median_pass_tes(percentages[2], 10)
        slitlet_test_result_list.append({tested_quantity: result})
        msg = " * Result of number of points greater than 5*threshold greater than 10%: " + result + "\n"
        print(msg)
        log_msgs.append(msg)

        # Make plots of normalized corrected data
        corrected = plsci / bscor
        plt.figure(figsize=(12, 10))
        norm = ImageNormalize(corrected,
                              vmin=0.,
                              vmax=500.,
                              stretch=AsinhStretch())
        plt.imshow(corrected,
                   norm=norm,
                   aspect=10.0,
                   origin='lower',
                   cmap='viridis')
        plt.title(
            'Normalized data before barshadow step with correction applied')
        plt.xlabel(
            'Sci_data_before_barshadow / barshadow_calculated_correction')
        plt.ylabel('Normalized data')
        # Show and/or save figures
        if save_intermediary_figs:
            t = (file_basename,
                 "Barshadowtest_CorrectedData_slitlet" + slit_id + ".pdf")
            plt_name = "_".join(t)
            plt_name = os.path.join(file_path, plt_name)
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_intermediary_figs:
            plt.show()
        plt.close()

        # calculate spatial profiles for both products
        fig, ((ax1, ax2)) = plt.subplots(1, 2, figsize=(19, 9))
        prof = np.median(corrected[:, point1[0]:point1[1]], 1)
        x = np.arange(corrected.shape[0])
        ax1.plot(x, prof)
        ax1.set_title('Before barshadow array slice 1')
        ax1.set_xlabel('x (pixels)')
        ax1.set_ylabel('y (pixels)')
        if debug:
            print('np.nanstd(prof)/np.nanmean(prof) = ',
                  np.nanstd(prof) / np.nanmean(prof))
        prof = np.median(corrected[:, point2[0]:point2[1]], 1)
        x = np.arange(corrected.shape[0])
        ax2.plot(x, prof)
        ax2.set_title('Before barshadow array slice 2')
        ax2.set_xlabel('x (pixels)')
        ax2.set_ylabel('y (pixels)')
        if debug:
            print('np.nanstd(prof)/np.nanmean(prof) = ',
                  np.nanstd(prof) / np.nanmean(prof))
        fig.suptitle('Corrected spatial profiles for slitlet ' + slit_id,
                     fontsize=20)
        # Show and/or save figures
        if save_intermediary_figs:
            t = (file_basename,
                 "Barshadowtest_CorrectedSpatialProfiles_slitlet" + slit_id +
                 ".pdf")
            plt_name = "_".join(t)
            plt_name = os.path.join(file_path, plt_name)
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_intermediary_figs:
            plt.show()
        plt.close()

        # store tests results in the total dictionary
        total_test_result[slit_id] = slitlet_test_result_list

        # create fits file to hold the calculated correction for each slitlet
        if write_barshadow_files:
            # this is the file to hold the image of the correction values
            outfile_ext = fits.ImageHDU(corrected, name=slit_id)
            outfile.append(outfile_ext)

            # this is the file to hold the image of pipeline-calculated difference values, the comparison
            complfile_ext = fits.ImageHDU(reldiff, name=slit_id)
            complfile.append(complfile_ext)

            # the file is not yet written, indicate that this slit was appended to list to be written
            msg = "Extension corresponing to slitlet " + slit_id + " appended to list to be written into calculated and comparison fits files."
            print(msg)
            log_msgs.append(msg)

    if debug:
        print('total_test_result = ', total_test_result)

    # If all tests passed then pytest will be marked as PASSED, else it will be FAILED
    FINAL_TEST_RESULT = True
    for sl, testlist in total_test_result.items():
        for tdict in testlist:
            for t, tr in tdict.items():
                if tr == "FAILED":
                    # a single failure marks the final result as FAILED
                    FINAL_TEST_RESULT = False
                    msg = "\n * The test of " + t + " for slitlet " + sl + "  FAILED."
                else:
                    msg = "\n * The test of " + t + " for slitlet " + sl + "  PASSED."
                print(msg)
                log_msgs.append(msg)

    if FINAL_TEST_RESULT:
        result_msg = "\n *** Final result for barshadow test will be reported as PASSED *** \n"
        print(result_msg)
        log_msgs.append(result_msg)
    else:
        result_msg = "\n *** Final result for barshadow test will be reported as FAILED *** \n"
        print(result_msg)
        log_msgs.append(result_msg)

    # end the timer
    barshadow_test_end_time = time.time() - barshadow_test_start_time
    if barshadow_test_end_time > 60.0:
        barshadow_test_end_time = barshadow_test_end_time / 60.0  # in minutes
        barshadow_test_tot_time = "* Barshadow validation test took " + repr(
            barshadow_test_end_time) + " minutes to finish."
        if barshadow_test_end_time > 60.0:
            barshadow_test_end_time = barshadow_test_end_time / 60.  # in hours
            barshadow_test_tot_time = "* Barshadow validation test took " + repr(
                barshadow_test_end_time) + " hours to finish."
    else:
        barshadow_test_tot_time = "* Barshadow validation test took " + repr(
            barshadow_test_end_time) + " seconds to finish."
    print(barshadow_test_tot_time)
    log_msgs.append(barshadow_test_tot_time)

    return FINAL_TEST_RESULT, result_msg, log_msgs
Example #19
for fn, pfx, coord_limits, (vmin, vmax), name, stretch in (
    ("W51Ku_BDarray_continuum_2048_both_uniform.hires.clean.image.fits", 'Ku',
     [
         (
             290.94225,
             14.505832,
         ),
         (
             290.93794,
             14.509374,
         ),
     ], (-1e-1, 3e-1), 'e11_bow', LinearStretch()),
    ("W51C_ACarray_continuum_4096_both_uniform_contsplit.clean.image.fits",
     'C', [(290.9304, 14.5083),
           (290.9194, 14.5189)], (0.01, 9), 'w51main_peak', AsinhStretch()),
    ("W51C_ACarray_continuum_4096_both_uniform_contsplit.clean.image.fits",
     'C', [(290.92462, 14.516962),
           (290.92352, 14.517891)], (2e-2, 2.0), 'e6', AsinhStretch()),
    ("W51C_ACarray_continuum_4096_both_uniform_contsplit.clean.image.fits",
     'C', [(290.90023, 14.523703),
           (290.89873, 14.525156)], (2e-2, 0.3), 'd3', AsinhStretch()),
    ("W51C_ACarray_continuum_4096_both_uniform_contsplit.clean.image.fits",
     'C', [(290.93729, 14.485868),
           (290.93594, 14.487230)], (-2e-2, 0.3), 'e7', AsinhStretch()),
    ("W51C_ACarray_continuum_4096_both_uniform_contsplit.clean.image.fits",
     'C', [(290.93283, 14.506909),
           (290.93200, 14.507676)], (6e-2, 9), 'e1', AsinhStretch()),
    ("W51C_ACarray_continuum_4096_both_uniform_contsplit.clean.image.fits",
     'C', [(290.92402, 14.513314),
           (290.90971, 14.525246)], (-6e-2, 5), 'irs2_C_low', AsinhStretch()),
Beispiel #20
0
def dsscut(name, time, ra, dec, radius, fol):

    plt.rc('text', usetex=False)
    plt.rc('font', family='serif')  # set the font
    #
    # basic parameters, can be changed in principle
    #
    size = 5  # [arcmin] the size of the retrieved and saved image
    size1 = 5  # [arcmin] the size of the image used for the FC
    pix = 1.008  # approx pixel scale [arcsec/pix]
    #
    # get the cutout image from the ESO archive
    #
    ra1, dec1 = (float(ra), float(dec))
    raS = ra.replace(':', '%3A')
    decS = dec.replace(':', '%3A')
    ra_dec = coord.SkyCoord(ra + ' ' + dec, unit=(u.deg, u.deg))
    ra_hms = '%02.f:%02.f:%05.2f' % (ra_dec.ra.hms)
    dec_dms = '%02.f:%02.f:%05.2f' % (ra_dec.dec.dms[0], abs(
        ra_dec.dec.dms[1]), abs(ra_dec.dec.dms[2]))
    #### size /arcminutes
    link = 'http://archive.eso.org/dss/dss/image?ra=' + raS + '&dec=' + decS + '&equinox=J2000&name=&x=' + str(
        size) + '&y=' + str(
            size
        ) + '&Sky-Survey=DSS2-red&mime-type=download-fits&statsmode=WEBFORM'
    #eg: 'http://archive.eso.org/dss/dss/image?ra=1%3A32%3A32&dec=32%3A12%3A32&equinox=J2000&name=&x=10&y=10&Sky-Survey=DSS2-red&mime-type=download-fits&statsmode=WEBFORM'
    outf_path = fol + '/' + name + '/' + name + '_dss_red.fits'
    csv_outf = fol + '/' + name + '/' + name + '.csv'
    try:
        #with urllib.request.urlopen(link) as response, open(outf, 'wb') as outf:
        response, outf = urllib.request.urlopen(link), open(
            outf_path, 'wb')  # response holds the page content; open a new fits file
        shutil.copyfileobj(response, outf)  # copy the downloaded content into the new fits file
        print('\t... image saved to:', outf_path)

        #
        # load image and make a FC
        #
        fh = fits.open(outf_path)

        fim = fh[0].data
        fhe = fh[0].header

        #
        # cut image and apply scale
        #
        imsize_list = [
            10,
        ]
        if (float(radius) * 120 < 3):
            imsize_list.append(3)
        for imsize in imsize_list:
            size1 = imsize
            x1 = int(30 * (size - size1))
            x2 = int(30 * (size + size1))
            y1 = int(30 * (size - size1))
            y2 = int(30 * (size + size1))
            fim = fim[y1:y2, x1:x2]

            fim[np.isnan(fim)] = 0.0
            transform = AsinhStretch() + PercentileInterval(99.7)
            bfim = transform(fim)

            # the header contains deprecated keywords; suppress the warnings
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                global wcs
                wcs = WCS(fhe)

            #
            # produce and save the FC
            #
            fig = plt.figure(2, figsize=(5, 5))
            fig1 = fig.add_subplot(111, aspect='equal')
            plt.imshow(bfim, cmap='gray_r', origin='lower')
            s_world = wcs.wcs_world2pix(np.array([
                [float(ra), float(dec)],
            ]), 1)[0]
            theta = np.linspace(0, 2 * np.pi, 8000)  # sample points along the error circle
            # the radius is given in degrees; *3600 converts it to arcsec, which
            # approximates pixels at the ~1 arcsec/pix DSS scale
            x, y = s_world[0] - x1 + np.cos(theta) * float(
                radius) * 3600, s_world[1] - y1 + np.sin(theta) * float(
                    radius) * 3600
            fig1.plot(x, y, color='red', linewidth=0.5)
            try:  # first try to download the Pan-STARRS catalog
                os.system(
                    'wget -nd -nc "https://catalogs.mast.stsci.edu/api/v0.1/panstarrs/dr1/mean?ra='
                    + str(raS) + '&dec=' + str(decS) + '&radius=' +
                    str((size) / 120.) +
                    '&nDetections.gte=1&pagesize=50001&format=csv"  ')
                os.system(
                    'mv ./mean\?ra\=' + str(raS) + '\&dec\=' + str(decS) +
                    '\&radius\=' + str(size / 120.) +
                    '\&nDetections.gte\=1\&pagesize\=50001\&format\=csv  ./'
                    + str(csv_outf))
                print('1')
                db_data = rdb(csv_outf)
                print('2')
                #sextab = make_asc_runsex(outf_path)
                print(len(db_data))
                for i in db_data:  # annotate magnitudes on the finding chart
                    db_ra = i[1]
                    db_dec = i[2]
                    db_mag = format(i[3], '0.2f')
                    db_world = wcs.wcs_world2pix(
                        np.array([
                            [float(db_ra), float(db_dec)],
                        ]), 1)[0]
                    fig1.text(db_world[0] - x1,
                              db_world[1] - y1,
                              str(db_mag),
                              color='green',
                              fontsize=7)
                    #print(db_world)
                    #fig1.plot(db_world[0], db_world[1], color='red', linewidth=0.5)
                    #fig.text(db_world[0]/(size*60.),db_world[1]/(size*60.),str(format(db_mag, '0.2f')),fontsize=5,color='green')
                #txtb=fig.text(0.06, 0.06, mm, fontsize=10, color='black')
                #txtb.set_path_effects([PathEffects.withStroke(linewidth=0.1, foreground='k')])
            except:  # if the Pan-STARRS download fails, try the SkyMapper catalog (mainly southern sky)
                try:
                    os.system(
                        'wget -nd -nc "http://skymapper.anu.edu.au/sm-cone/public/query?RA='
                        + str(raS) + '&DEC=' + str(decS) + '&SR=' +
                        str((size) / 120.) + '&format=csv"  ')
                    os.system('mv ./query\?RA\=' + str(raS) + '\&DEC\=' +
                              str(decS) + '\&SR\=' + str(size / 120.) +
                              '\&format\=csv  ./' + str(csv_outf))
                    #skymapper keys
                    globals()['ram'] = 'raj2000'
                    globals()['ra_err'] = 'e_raj2000'
                    globals()['decm'] = 'dej2000'
                    globals()['dec_err'] = 'e_dej2000'
                    # u g r i z
                    globals()['mm'] = 'r_psf'
                    globals()['mm_err'] = 'e_r_psf'
                    print('1')
                    db_data = rdb(csv_outf)
                    print('2')
                    #sextab = make_asc_runsex(outf_path)
                    print(len(db_data))
                    for i in db_data:
                        db_ra = i[1]
                        db_dec = i[2]
                        db_mag = format(i[3], '0.2f')
                        db_world = wcs.wcs_world2pix(
                            np.array([
                                [float(db_ra), float(db_dec)],
                            ]), 1)[0]
                        fig1.text(db_world[0] - x1,
                                  db_world[1] - y1,
                                  str(db_mag),
                                  color='green',
                                  fontsize=7)
                except:
                    txtb = fig.text(0.45,
                                    0.06,
                                    'NO Panstarrs AND skymapper',
                                    fontsize=10,
                                    color='black')
            # small box in the lower-left corner of the figure holding the GRB
            # info (time, coordinates, image field of view, etc.)
            fig2 = plt.axes([0.0, 0.0, 0.4, 0.12])
            fig2.set_facecolor('w')
            txta = fig.text(0.02,
                            0.08,
                            'GRB' + time + '  DSS ' + str(imsize) + '\' x ' +
                            str(imsize) + '\'',
                            fontsize=7,
                            color='black')
            txta = fig.text(0.5, 0.95, 'N', fontsize=18, color='black')
            txta = fig.text(0.01, 0.5, 'E', fontsize=18, color='black')
            txta = fig.text(0.02,
                            0.05,
                            'GRB   ra = ' + ra_hms + ' (' + raS + ')',
                            fontsize=7,
                            color='black')
            txta = fig.text(0.02,
                            0.02,
                            'GRB dec = ' + dec_dms + ' (' + decS + ')',
                            fontsize=7,
                            color='black')
            #txta.set_path_effects([PathEffects.withStroke(linewidth=0.1, foreground='k')])
            #txtb=fig.text(0.9,0.95,'DSS',fontsize=10,color='black')
            #txtb.set_path_effects([PathEffects.withStroke(linewidth=0.1, foreground='k')])
            fig1.add_patch(
                FancyArrowPatch((size1 * 60 - 70 - 10, 20),
                                (size1 * 60 - 10, 20),
                                arrowstyle='-',
                                color='k',
                                linewidth=1.5))
            #fig1.add_patch(FancyArrowPatch((size1*60-15/pix-10,20),(size1*60-10,20),arrowstyle='-',color='black',linewidth=2.0))
            txtc = fig.text(0.9, 0.06, '60\'\'', fontsize=10, color='black')
            txtc.set_path_effects(
                [PathEffects.withStroke(linewidth=0.1, foreground='k')])
            #plt.gca().xaxis.set_major_locator(plt.NullLocator())
            #plt.gca().yaxis.set_major_locator(plt.NullLocator())
            #fig2=plt.axes([0.0, 0.64, 0.1, 0.65])
            #lena = mpimg.imread('a.jpg')
            #lena.shape #(512, 512, 3)
            #plt.imshow(lena)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.subplots_adjust(top=1,
                                bottom=0,
                                left=0,
                                right=1,
                                hspace=0,
                                wspace=0)
            globals()['fname' + str(imsize)] = name + str(imsize) + '_dss.png'
            plt.savefig(fol + '/' + name + '/' +
                        globals()['fname' + str(imsize)],
                        dpi=300,
                        format='PNG')
            fig.clear()

        if (float(radius) * 120 < 3):
            return {
                'pngname10': fol + '/' + name + '/' + globals()['fname10'],
                'pngname3': fol + '/' + name + '/' + globals()['fname3'],
                'dssname': outf_path,
                'dsslink': link
            }
        else:
            return {
                'pngname10': fol + '/' + name + '/' + globals()['fname10'],
                'dssname': outf_path,
                'dsslink': link
            }

    except Exception as e:
        print(str(e))
        return {'dss_err': '-99', 'dsslink': link}
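A hedged usage sketch of dsscut: the GRB name, coordinates, and output folder are hypothetical; ra/dec are decimal-degree strings, radius is in degrees, the directory fol/name must already exist, and network access to the ESO archive and the catalog servers is required:

    result = dsscut(name='GRB210101A', time='210101', ra='150.1234', dec='2.2057',
                    radius='0.0008', fol='./fc_output')  # hypothetical values
    if 'dss_err' not in result:
        print(result['pngname10'])  # path to the 10' finding chart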
Beispiel #21
0
def data_to_pitch(data_array, pitch_range=[100, 10000], center_pitch=440, zero_point="median",
                  stretch='linear', minmax_percent=None, minmax_value=None, invert=False):
    """
    Map data array to audible pitches in the given range, and apply stretch and scaling
    as required.

    Parameters
    ----------
    data_array : array-like
        Data to map to pitch values. Individual data values should be floats.
    pitch_range : array
        Optional, default [100,10000]. Range of acceptable pitches in Hz.
    center_pitch : float
        Optional, default 440. The pitch in Hz that the zero point of the
        data will be mapped to.
    zero_point : str or float
        Optional, default "median". The data value that will be mapped to the center
        pitch. Options are mean, median, or a specified data value (float).
    stretch : str
        Optional, default 'linear'. The stretch to apply to the data array.
        Valid values are: asinh, sinh, sqrt, log, linear
    minmax_percent : array
        Optional. Interval based on a keeping a specified fraction of data values
        (can be asymmetric) when scaling the data. The format is [lower percentile,
        upper percentile], where data values below the lower percentile and above the upper
        percentile are clipped. Only one of minmax_percent and minmax_value should be specified.
    minmax_value : array
        Optional. Interval based on user-specified data values when scaling the data array.
        The format is [min value, max value], where data values below the min value and above
        the max value are clipped.
        Only one of minmax_percent and minmax_value should be specified.
    invert : bool
        Optional, default False.  If True the pitch array is inverted (low pitches become high
        and vice versa).

    Returns
    -------
    response : array
        The normalized data array, with values in given pitch range.
    """
    # Parsing the zero point
    if zero_point in ("med", "median"):
        zero_point = np.median(data_array)
    if zero_point in ("ave", "mean", "average"):
        zero_point = np.mean(data_array)

    # The center pitch cannot be >= max() pitch range, or <= min() of pitch range.
    # If it is, fall back to using the mean of the pitch range provided.
    if center_pitch <= pitch_range[0] or center_pitch >= pitch_range[1]:
        warnings.warn("Given center pitch is outside the pitch range, defaulting to the mean.",
                      InputWarning)
        center_pitch = np.mean(pitch_range)

    if (data_array == zero_point).all():  # All values are the same, no more calculation needed
        return np.full(len(data_array), center_pitch)

    # Normalizing the data_array and adding the zero point (so it can go through the same transform)
    data_array = np.append(np.array(data_array), zero_point)

    # Setting up the transform with the stretch
    if stretch == 'asinh':
        transform = AsinhStretch()
    elif stretch == 'sinh':
        transform = SinhStretch()
    elif stretch == 'sqrt':
        transform = SqrtStretch()
    elif stretch == 'log':
        transform = LogStretch()
    elif stretch == 'linear':
        transform = LinearStretch()
    else:
        raise InvalidInputError("Stretch {} is not supported!".format(stretch))

    # Adding the scaling to the transform
    if minmax_percent is not None:
        transform += AsymmetricPercentileInterval(*minmax_percent)

        if minmax_value is not None:
            warnings.warn("Both minmax_percent and minmax_value are set, minmax_value will be ignored.",
                          InputWarning)
    elif minmax_value is not None:
        transform += ManualInterval(*minmax_value)
    else:  # Default, scale the entire image range to [0,1]
        transform += MinMaxInterval()

    # Performing the transform and then putting it into the pitch range
    pitch_array = transform(data_array)

    if invert:
        pitch_array = 1 - pitch_array

    zero_point = pitch_array[-1]
    pitch_array = pitch_array[:-1]

    # In rare cases, the zero-point at this stage might be 0.0.
    # One example is an input array of two values where the median() is the same as the
    # lowest of the two values. In this case, the zero-point is 0.0 and will lead to error
    # (divide by zero). Change to small value to avoid dividing by zero (in reality the choice
    # of zero-point calculation by the user was probably poor, but not in purview to mandate or
    # change user's choice here.  May want to consider providing info back to the user about the
    # distribution of pitches actually used based on their sonification options in some way.
    if zero_point == 0.0:
        zero_point = 1E-6

    if ((1/zero_point)*(center_pitch - pitch_range[0]) + pitch_range[0]) <= pitch_range[1]:
        pitch_array = (pitch_array/zero_point)*(center_pitch - pitch_range[0]) + pitch_range[0]
    else:
        pitch_array = (((pitch_array-zero_point)/(1-zero_point))*(pitch_range[1] - center_pitch) +
                       center_pitch)

    return pitch_array
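A short usage sketch with invented values: four data points mapped to pitches between 100 and 1000 Hz through an asinh stretch; the median of the data lands on the 440 Hz center pitch:

    import numpy as np
    pitches = data_to_pitch(np.array([0.0, 0.5, 1.0, 5.0]),
                            pitch_range=[100, 1000], stretch='asinh')
    print(pitches)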
Beispiel #22
0
def run_msa_flagging_testing(input_file, msa_flagging_threshold=99.5, rate_obj=None,
                             stellarity=None, operability_ref=None, source_type=None,
                             save_figs=False, show_figs=True, debug=False):
    """
    This is the validation function for the msa flagging step.
    :param input_file: string, fits file output from the msa_flagging step
    :param msa_flagging_threshold: float, percentage for all slits with more than 100 pixels
    :param rate_obj: object, the stage 1 pipeline output object
    :param stellarity: float, stellarity number from 0.0 to 1.0
    :param operability_ref: string, msa failed open - operability - reference file
    :param source_type: string, options are point, extended, unknown
    :param save_figs: boolean
    :param show_figs: boolean
    :param debug: boolean
    :return:
        FINAL_TEST_RESULT: boolean, True if smaller than or equal to threshold
        result_msg: string, message with reason for passing, failing, or skipped
        log_msgs: list, diagnostic strings to be printed in log

    """
    # start the list of messages that will be added to the log file
    log_msgs = []

    # start the timer
    msa_flagging_test_start_time = time.time()

    # get the data model
    if isinstance(input_file, str):
        msaflag = datamodels.open(input_file)
    else:
        msaflag = input_file

    if debug:
        print('got MSA flagging datamodel!')

    # set up generals for all the plots
    font = {'weight': 'normal',
            'size': 12}
    matplotlib.rc('font', **font)

    # plot full image
    fig = plt.figure(figsize=(9, 9))
    norm = ImageNormalize(msaflag.data, vmin=0., vmax=50., stretch=AsinhStretch())
    plt.imshow(msaflag.data, norm=norm, aspect=1.0, origin='lower', cmap='viridis')
    # Show and/or save figures
    detector = msaflag.meta.instrument.detector
    datadir = None
    if save_figs:
        file_basename = os.path.basename(input_file.replace("_msa_flagging.fits", ""))
        datadir = os.path.dirname(input_file)
        t = (file_basename, "MSA_flagging_full_detector.png")
        plt_name = "_".join(t)
        plt_name = os.path.join(datadir, plt_name)
        plt.savefig(plt_name)
        print('Figure saved as: ', plt_name)
    if show_figs:
        plt.show()
    plt.close()

    # read in DQ flags from MSA_flagging product
    # find all pixels that have been flagged by this step  as MSA_FAILED_OPEN -> DQ array value = 536870912
    # https://jwst-pipeline.readthedocs.io/en/latest/jwst/references_general/references_general.html?highlight=536870912#data-quality-flags
    dq_flag = 536870912
    msaflag_1d = msaflag.dq.flatten()
    index_opens = np.squeeze(np.asarray(np.where(msaflag_1d & dq_flag)))
    if debug:
        print("DQ array at 167, 1918: ", msaflag.dq[167, 1918])
        # np.set_printoptions(threshold=sys.maxsize)  # print all elements in array
        print("Index where Failed Open shutters exist: ", np.shape(index_opens), index_opens)

    # execute script that creates an MSA metafile for the failed open shutters
    # read operability reference file
    """
    crds_path = os.environ.get('CRDS_PATH')
    if crds_path is None:
        print("(msa_flagging_testing): The environment variable CRDS_PATH is not defined. To set it, follow the "
              "instructions at: \n"
              "                        https://github.com/spacetelescope/nirspec_pipe_testing_tool")
        exit()
        """

    crds_path = "https://jwst-crds.stsci.edu/unchecked_get/references/jwst/"
    op_ref_file = "jwst_nirspec_msaoper_0001.json"

    if operability_ref is None:
        ref_file = os.path.join(crds_path, op_ref_file)
        urllib.request.urlretrieve(ref_file, op_ref_file)
    else:
        op_ref_file = operability_ref

    if "http" not in op_ref_file:
        if not os.path.isfile(op_ref_file):
            result_msg = "Skipping msa_flagging test because the operability reference file does not exist: " + \
                         op_ref_file
            print(result_msg)
            log_msgs.append(result_msg)
            result = 'skip'
            return result, result_msg, log_msgs

    if debug:
        print("Using this operability reference file: ", op_ref_file)

    with open(op_ref_file) as f:
        msaoper_dict = json.load(f)
    msaoper = msaoper_dict["msaoper"]

    # find the failed open shutters
    failedopens = [(c["Q"], c["x"], c["y"]) for c in msaoper if c["state"] == 'open']
    if debug:
        print("Failed Open shutters: ", failedopens)

    # unpack the list of tuples into separate lists for MSA quadrant, row, column locations
    quads, allrows, allcols = zip(*failedopens)

    # stellarity -- internal lamps are uniform illumination, so set to 0
    # ** if considering a point source, need to change this to 1, or actual value if known
    if source_type is None:
        # srctyapt = fits.getval(input_file, 'SRCTYAPT')  # previously used
        srctyapt = msaflag.meta.target.source_type_apt
    else:
        srctyapt = source_type.upper()
    if stellarity is None:
        if "POINT" in srctyapt:
            stellarity = 1.0
        else:
            stellarity = 0.0
    else:
        stellarity = float(stellarity)

    # create MSA metafile with F/O shutters
    if datadir is not None:
        fometafile = os.path.join(datadir, 'fopens_metafile_msa.fits')
    else:
        fometafile = 'fopens_metafile_msa.fits'
    if not os.path.isfile(fometafile):
        pattnum = msaflag.meta.dither.position_number
        create_metafile_fopens(fometafile, allcols, allrows, quads, stellarity, failedopens,
                               pattnum, save_fig=save_figs, show_fig=show_figs, debug=debug)

    # run assign_wcs on the science exposure using F/O metafile
    # change MSA metafile name in header to match the F/O metafile name
    if isinstance(input_file, str):
        rate_file = input_file.replace("msa_flagging", "rate")
        if not os.path.isfile(rate_file):
            # if a _rate.fits file does not exist try the usual name
            rate_file = os.path.join(datadir, 'final_output_caldet1_'+detector+'.fits')
            if not os.path.isfile(rate_file):
                result_msg = "Skipping msa_flagging test because no rate fits file was found in directory: " + datadir
                print(result_msg)
                log_msgs.append(result_msg)
                result = 'skip'
                return result, result_msg, log_msgs
        if debug:
            print("Will run assign_wcs with new Failed Open fits file on this file: ", rate_file)
        rate_mdl = datamodels.ImageModel(rate_file)
    else:
        rate_mdl = rate_obj

    if debug:
        print("MSA metadata file in initial rate file: ", rate_mdl.meta.instrument.msa_metadata_file)

    rate_mdl.meta.instrument.msa_metadata_file = fometafile
    if debug:
        print("New MSA metadata file in rate file: ", rate_mdl.meta.instrument.msa_metadata_file)

    # force the exp_type of this new model to MSA, even if IFU so that nrs_wcs_set_input pipeline function works
    if "ifu" in msaflag.meta.exposure.type.lower():
        rate_mdl.meta.exposure.type = 'NRS_MSASPEC'

    # run assign_wcs; use +/-0.45 for the y-limits because the default is too big (0.6 including buffer)
    stp = AssignWcsStep()
    awcs_fo = stp.call(rate_mdl, slit_y_low=-0.45, slit_y_high=0.45)

    # get the slits from the F/O processing run
    slits_list = awcs_fo.meta.wcs.get_transform('gwa', 'slit_frame').slits

    # prepare arrays to hold info needed for validation test
    allsizes = np.zeros(len(slits_list))
    allchecks = np.zeros(len(slits_list))

    # loop over the slits and compare pixel bounds with the flagged pixels from the original product
    for i, slit in enumerate(slits_list):
        try:
            name = slit.name
        except AttributeError:
            name = i
        print("\nWorking with slit/slice: ", name)
        if "IFU" not in msaflag.meta.exposure.type.upper():
            print("Slit min and max in y direction: ", slit.ymin, slit.ymax)
        # get the WCS object for this particular slit
        wcs_slice = nirspec.nrs_wcs_set_input(awcs_fo, name)
        # get the bounding box for the 2D subwindow, round to nearest integer, and convert to integer
        bbox = np.rint(wcs_slice.bounding_box)
        bboxint = bbox.astype(int)
        print("bounding box rounded to next integer: ", bboxint)
        i1 = bboxint[0, 0]
        i2 = bboxint[0, 1]
        i3 = bboxint[1, 0]
        i4 = bboxint[1, 1]
        # make array of pixel locations within bounding box
        x, y = np.mgrid[i1:i2, i3:i4]
        index_1d = np.ravel_multi_index([[y], [x]], (2048, 2048))
        # get the slity WCS parameter to find which pixels are located in the actual spectrum
        det2slit = wcs_slice.get_transform('detector', 'slit_frame')
        slitx, slity, _ = det2slit(x, y)
        print("Max value in slity array (ignoring NANs): ", np.nanmax(slity))
        index_trace = np.squeeze(index_1d)[~np.isnan(slity)]
        n_overlap = np.sum(np.isin(index_opens, index_trace))
        overlap_percentage = round(n_overlap/index_trace.size*100., 1)
        if debug:
            print("Size of index_trace= ", index_trace.size)
            print("Size of index_opens=", index_opens.size)
            print("Sum of values found in index_opens and index_trace=", n_overlap)
        msg = 'percentage of F/O trace that was flagged: ' + repr(overlap_percentage)
        print(msg)
        log_msgs.append(msg)
        allchecks[i] = overlap_percentage
        allsizes[i] = index_trace.size

        # show 2D cutouts, with flagged pixels overlaid
        # calculate wavelength, slit_y values for the subwindow
        det2slit = wcs_slice.get_transform('detector', 'slit_frame')
        slitx, slity, _ = det2slit(x, y)

        # extract & display the F/O 2d subwindows from the msa_flagging sci image
        fig = plt.figure(figsize=(19, 19))
        subwin = msaflag.data[i3:i4, i1:i2].copy()
        # set all pixels outside of the nominal shutter length to 0, inside to 1
        subwin[np.isnan(slity.T)] = 0
        subwin[~np.isnan(slity.T)] = 1
        # find the pixels flagged by the msaflagopen step; set them to 1 and everything else to 0 for ease of display
        subwin_dq = msaflag.dq[i3:i4, i1:i2].copy()
        mask = np.zeros(subwin_dq.shape, dtype=bool)
        mask[np.where(subwin_dq & 536870912)] = True
        subwin_dq[mask] = 1
        subwin_dq[~mask] = 0
        # plot the F/O traces
        vmax = np.max(msaflag.data[i3:i4, i1:i2])
        norm = ImageNormalize(msaflag.data[i3:i4, i1:i2], vmin=0., vmax=vmax, stretch=AsinhStretch())
        plt.imshow(msaflag.data[i3:i4, i1:i2], norm=norm, aspect=10.0, origin='lower', cmap='viridis',
                   label='MSA flagging data')
        plt.imshow(subwin, aspect=20.0, origin='lower', cmap='Reds', alpha=0.3, label='Calculated F/O')
        # overplot the flagged pixels in translucent grayscale
        plt.imshow(subwin_dq, aspect=20.0, origin='lower', cmap='gray', alpha=0.3, label='Pipeline F/O')
        if save_figs:
            t = (file_basename, "FailedOpen_detector", detector, "slit", repr(name) + ".png")
            plt_name = "_".join(t)
            plt_name = os.path.join(datadir, plt_name)
            plt.savefig(plt_name)
            print('Figure saved as: ', plt_name)
        if show_figs:
            plt.show()
        plt.close()

    # validation: overlap should be >= msa_flagging_threshold percent for all slits with more than 100 pixels
    FINAL_TEST_RESULT = False
    if not isinstance(msa_flagging_threshold, float):
        msa_flagging_threshold = float(msa_flagging_threshold)
    if (allchecks[allsizes >= 100] >= msa_flagging_threshold).all():
        FINAL_TEST_RESULT = True
    else:
        print("\n * One or more traces show msa_flagging match < ", repr(msa_flagging_threshold))
        print("   See results above per trace. \n")
    if FINAL_TEST_RESULT:
        result_msg = "\n *** Final result for msa_flagging test will be reported as PASSED *** \n"
        print(result_msg)
        log_msgs.append(result_msg)
    else:
        result_msg = "\n *** Final result for msa_flagging test will be reported as FAILED *** \n"
        print(result_msg)
        log_msgs.append(result_msg)

    # end the timer
    msa_flagging_test_end_time = time.time() - msa_flagging_test_start_time
    if msa_flagging_test_end_time >= 60.0:
        msa_flagging_test_end_time = msa_flagging_test_end_time/60.0  # in minutes
        msa_flagging_test_tot_time = "* MSA flagging validation test took " + repr(msa_flagging_test_end_time) + \
                                     " minutes to finish."
        if msa_flagging_test_end_time >= 60.0:
            msa_flagging_test_end_time = msa_flagging_test_end_time/60.  # in hours
            msa_flagging_test_tot_time = "* MSA flagging validation test took " + repr(msa_flagging_test_end_time) + \
                                         " hours to finish."
    else:
        msa_flagging_test_tot_time = "* MSA flagging validation test took " + repr(msa_flagging_test_end_time) + \
                                     " seconds to finish."
    print(msa_flagging_test_tot_time)
    log_msgs.append(msa_flagging_test_tot_time)

    # close the datamodel
    msaflag.close()
    rate_mdl.close()

    return FINAL_TEST_RESULT, result_msg, log_msgs
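A hedged call sketch for run_msa_flagging_testing: the file name is hypothetical, a matching _rate.fits file is expected next to it, and the operability reference file is fetched from CRDS when not supplied:

    result, result_msg, log_msgs = run_msa_flagging_testing(
        'final_output_msa_flagging.fits', msa_flagging_threshold=99.5,
        save_figs=True, show_figs=False)
    print(result_msg)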
Beispiel #23
0
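    # Assumed context for this fragment (not shown in the original): hdr and Inu are
    # the header and image data of a fits HDU, cc and kk are the speed of light and
    # Boltzmann constant in cgs units, and names, aoffs, doffs, vmins, vmaxs, xlims,
    # labels, dpc, fig, gs, cm and i come from the enclosing plotting loop.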
    # convert to brightness temperature
    beam = np.pi * hdr['BMAJ'] * hdr['BMIN'] / (4.*np.log(2.))
    beam *= (np.pi/180.)**2
    nu = hdr['CRVAL3']
    Tb = (1e-23 * Inu / beam) * cc**2 / (2.*kk*nu**2)
    print(names[i], nu, 3.6e6*hdr['BMAJ'], 3.6e6*hdr['BMIN'], hdr['BPA'])

    # define coordinate grid
    RA  = 3600. * hdr['CDELT1'] * (np.arange(hdr['NAXIS1'])-(hdr['CRPIX1']-1))
    DEC = 3600. * hdr['CDELT2'] * (np.arange(hdr['NAXIS2'])-(hdr['CRPIX2']-1))
    ext = (np.max(RA)-aoffs[i], np.min(RA)-aoffs[i], 
           np.min(DEC)-doffs[i], np.max(DEC)-doffs[i])

    # plot the image 
    ax = fig.add_subplot(gs[np.floor_divide(i, 4), i%4])
    norm = ImageNormalize(vmin=vmins[i], vmax=vmaxs[i], stretch=AsinhStretch())
    im = ax.imshow(Tb, origin='lower', cmap=cm, 
                   extent=ext, aspect='equal', norm=norm)

    # plot beam
    beam = Ellipse(((xlims[i])[0] + 0.1*np.diff(xlims[i]), 
                    (xlims[i])[1] - 0.1*np.diff(xlims[i])), 
                   hdr['BMAJ']*3600., hdr['BMIN']*3600., 90.-hdr['BPA'])
    beam.set_facecolor('w')
    ax.add_artist(beam)

    # plot the annotations
    ax.annotate(labels[i], xy=(0.06, 0.87), xycoords='axes fraction', 
                horizontalalignment='left', color='w')
    dr = 10.
    ax.plot([(xlims[i])[1] - 0.1*np.diff(xlims[i]), dr / dpc[i] + \
Beispiel #24
0
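    # Assumed context for this fragment (not shown in the original): name is a list of
    # object identifiers, hips_list a list of HiPS survey IDs, width/height/fov the
    # hips2fits cutout geometry, ap a list collecting output file names, and i a
    # running index over name.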
    obj = name[i]
    sc = SkyCoord.from_name(obj)
    ra = sc.icrs.ra.deg
    dec = sc.icrs.dec.deg
    for hips in hips_list:
        url = 'http://alasky.u-strasbg.fr/hips-image-services/hips2fits?hips={}&width={}&height={}&fov={}&projection=TAN&coordsys=icrs&ra={}&dec={}'.format(
            quote(hips), width, height, fov, ra, dec)
        hdu = fits.open(url)
        wcs = WCS(hdu[0].header)
        plt.figure()
        plt.subplot(projection=wcs)
        im = hdu[0].data
        norm = ImageNormalize(im,
                              interval=MinMaxInterval(),
                              stretch=AsinhStretch())
        file_name = '{}-{}.jpg'.format(obj, hips.replace('/', '_'))
        ap.append(file_name)
        plt.imshow(im, cmap='Greys', norm=norm, origin='lower')
        px, py = wcs.wcs_world2pix(ra, dec, 1)
        plt.scatter(px, py, c='g', s=100)
        plt.title('{} - {}'.format(obj, hips))
        plt.xlabel('RA')
        plt.ylabel('DEC')
        print(file_name)
        plt.savefig(file_name)
    i = i + 1

#%%

import img2pdf
Beispiel #25
0
def create_metafile_fopens(outfile, allcols, allrows, quads, stellarity, failedopens, pattnum,
                           save_fig=False, show_fig=False, debug=False):
    """
    Create the metafile for the failed open shutters according to msa failed open - operability - reference file.
    :param outfile: string, new fits msa meta file
    :param allcols: array
    :param allrows: array
    :param quads: array
    :param stellarity: array
    :param failedopens: array
    :param pattnum: integer
    :param save_fig: boolean
    :param show_fig: boolean
    :param debug: boolean
    :return: nothing
    """
    now = datetime.datetime.now()
    # set up metafile table structure
    # SHUTTER_INFO table
    # slitlet IDs - use one per shutter
    slitlets = np.arange(0, len(allcols)) + 1
    # source IDs - arbitrary for ground data, use slitlet ID
    sources = slitlets
    # metadata IDs - arbitrary for ground data
    metaids = np.full_like(allcols, 1)
    # background shutter?  default to all "N" for ground data
    bkgd = np.full(len(allcols), 'N', dtype=str)
    # shutter state - "OPEN", by definition
    state = np.full(len(allcols), 'OPEN', dtype="<U8")
    # source position in shutter - N/A for ground data, assume centered
    srcx = np.full(len(allcols), 0.)
    srcy = srcx
    dithptind = np.full(len(allcols), pattnum)
    psrc = np.full(len(allcols), 'Y')
    tabcol1 = fits.Column(name='SLITLET_ID', format='I', array=slitlets)
    tabcol2 = fits.Column(name='MSA_METADATA_ID', format='I', array=metaids)
    tabcol3 = fits.Column(name='SHUTTER_QUADRANT', format='I', array=quads)
    tabcol4 = fits.Column(name='SHUTTER_ROW', format='I', array=allrows)
    tabcol5 = fits.Column(name='SHUTTER_COLUMN', format='I', array=allcols)
    tabcol6 = fits.Column(name='SOURCE_ID', format='I', array=sources)
    tabcol7 = fits.Column(name='BACKGROUND', format='A', array=bkgd)
    tabcol8 = fits.Column(name='SHUTTER_STATE', format='4A', array=state)
    tabcol9 = fits.Column(name='ESTIMATED_SOURCE_IN_SHUTTER_X', format='E', array=srcx)
    tabcol10 = fits.Column(name='ESTIMATED_SOURCE_IN_SHUTTER_Y', format='E', array=srcy)
    tabcol11 = fits.Column(name='DITHER_POINT_INDEX', format='I', array=dithptind)
    tabcol12 = fits.Column(name='PRIMARY_SOURCE', format='A', array=psrc)
    hdu2 = fits.BinTableHDU.from_columns(
        [tabcol1, tabcol2, tabcol3, tabcol4, tabcol5, tabcol6, tabcol7, tabcol8, tabcol9, tabcol10, tabcol11, tabcol12],
        name='SHUTTER_INFO')
    # SOURCE_INFO table
    # program ID - arbitrary
    program = np.full_like(allcols, 1)
    # source name - arbitrary
    name = np.full(len(allcols), 'lamp', dtype="<U8")
    # source alias - arbitrary
    alias = np.full(len(allcols), 'foo', dtype="<U8")
    # catalog ID - arbitrary
    catalog = np.full(len(allcols), 'foo', dtype="<U8")
    # RA, DEC - N/A for ground data
    ra = np.full(len(allcols), 0.)
    dec = ra
    # preimage file name - N/A
    preim = np.full(len(allcols), 'foo.fits', dtype="<U8")
    stellarity_arr = np.full(len(allcols), stellarity)
    tabcol1 = fits.Column(name='PROGRAM', format='I', array=program)
    tabcol2 = fits.Column(name='SOURCE_ID', format='I', array=sources)
    tabcol3 = fits.Column(name='SOURCE_NAME', format='4A', array=name)
    tabcol4 = fits.Column(name='ALIAS', format='3A', array=alias)
    tabcol5 = fits.Column(name='RA', format='D', array=ra)
    tabcol6 = fits.Column(name='DEC', format='D', array=dec)
    tabcol7 = fits.Column(name='PREIMAGE_ID', format='8A', array=preim)
    tabcol8 = fits.Column(name='STELLARITY', format='E', array=stellarity_arr)
    hdu3 = fits.BinTableHDU.from_columns([tabcol1, tabcol2, tabcol3, tabcol4, tabcol5, tabcol6, tabcol7, tabcol8],
                                         name='SOURCE_INFO')

    # create image of shutter status for first extension of the metafile
    # do each quadrant first, set value corresponding to each shutter depending on its status
    # (=0 if closed, 1 if open), then concatenate
    q1all = np.zeros(shape=(171, 365))
    q1all[[foid[2] for foid in failedopens if foid[0] == 1], [foid[1] for foid in failedopens if foid[0] == 1]] = 1
    q2all = np.zeros(shape=(171, 365))
    q2all[[foid[2] for foid in failedopens if foid[0] == 2], [foid[1] for foid in failedopens if foid[0] == 2]] = 1
    q3all = np.zeros(shape=(171, 365))
    q3all[[foid[2] for foid in failedopens if foid[0] == 3], [foid[1] for foid in failedopens if foid[0] == 3]] = 1
    q4all = np.zeros(shape=(171, 365))
    q4all[[foid[2] for foid in failedopens if foid[0] == 4], [foid[1] for foid in failedopens if foid[0] == 4]] = 1
    if debug:
        print("Quadrant 1 shape: ", q1all.shape)
        print("open shutters in Quadrant 1: ", np.where(q1all == 1))
        print("open shutters in Quadrant 2: ", np.where(q2all == 1))
        print("open shutters in Quadrant 3: ", np.where(q3all == 1))
        print("open shutters in Quadrant 4: ", np.where(q4all == 1))
    im1 = np.concatenate((q1all, q2all))
    im2 = np.concatenate((q3all, q4all))
    image = np.concatenate((im1, im2), axis=1)

    # set up generals for all the plots
    font = {'weight': 'normal',
            'size': 12}
    matplotlib.rc('font', **font)

    # plotting
    fig = plt.figure(figsize=(9, 9))
    norm = ImageNormalize(image, vmin=0., vmax=1., stretch=AsinhStretch())
    plt.imshow(image, norm=norm, aspect=1.0, origin='lower', cmap='viridis')
    if save_fig:
        datadir = os.path.dirname(outfile)
        plt_name = "FailedOpen_shutters" + ".png"
        plt_name = os.path.join(datadir, plt_name)
        plt.savefig(plt_name)
        print('Figure saved as: ', plt_name)
    if show_fig:
        plt.show()
    plt.close()

    # create fits file
    hdu0 = fits.PrimaryHDU()
    # add necessary keywords to primary header
    hdr = hdu0.header
    hdr.set('ORIGIN', 'STScI', 'institution responsible for creating FITS file')
    hdr.set('TELESCOP', 'JWST', 'telescope used to acquire data')
    hdr.set('INSTRUME', 'NIRSPEC', 'identifier for instrument used to acquire data')
    hdr.set('DATE', now.isoformat())
    hdr.set('FILENAME', outfile, 'name of file')
    hdr.set('PPSDBVER', 'PPSDB999',
            'version of PPS database used')  # N/A for non-OSS ground data, using arbitrary number
    hdr.set('PROGRAM', '1', 'program number')  # arbitrary
    hdr.set('VISIT', '1', 'visit number')  # arbitrary
    hdr.set('OBSERVTN', '1', 'observation number')  # arbitrary
    hdr.set('VISIT_ID', '1', 'visit identifier')  # arbitrary
    hdr.set('PNTG_SEQ', 1, 'pointing sequence number')  # arbitrary
    hdr.set('MSACFG10', 1, 'base 10 nirspec msa_at_pointing.msa_config_id')  # arbitrary
    hdr.set('MSACFG36', '01', 'base 36 version of MSACFG10')  # arbitrary
    hdu1 = fits.ImageHDU(image, name='SHUTTER_IMAGE')
    hdu_all = fits.HDUList([hdu0, hdu1, hdu2, hdu3])
    hdu_all.writeto(outfile)
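The SHUTTER_IMAGE built above tiles the four 171x365 quadrant arrays into a single 342x730 image. A minimal sketch of that tiling:

    import numpy as np
    quads = [np.zeros((171, 365)) for _ in range(4)]
    left = np.concatenate((quads[0], quads[1]))   # Q1 stacked over Q2
    right = np.concatenate((quads[2], quads[3]))  # Q3 stacked over Q4
    image = np.concatenate((left, right), axis=1)
    assert image.shape == (342, 730)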
"""
TODO: Maybe combine this with movies.py?
"""

import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as anim
import matplotlib.colors as colors
from astropy.io import fits
from astropy.visualization import ImageNormalize, AsinhStretch, LinearStretch, \
    LogStretch, PowerDistStretch, PowerStretch, SinhStretch, SqrtStretch, SquaredStretch
import logging
import mkidpipeline as pipe

VALID_STRETCHES = {
    'asinh': AsinhStretch(),
    'linear': LinearStretch(),
    'log': LogStretch(),
    'powerdist': PowerDistStretch(),
    'sinh': SinhStretch(),
    'sqrt': SqrtStretch(),
    'squared': SquaredStretch()
}
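# Hedged usage sketch: the mapping above lets a CLI stretch name select the transform,
# e.g. norm = ImageNormalize(frame, stretch=VALID_STRETCHES['asinh']), where `frame`
# would be one time slice of the temporal cube described below.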

log = logging.getLogger(__name__)


def read_fits(file, wvl_bin=None):
    """
    Reads a fits file into temporal and spectral cubes. Temporal cube has dimensions of [time, x, y] and spectral cube
    has dimensions of [wvl, x, y].
    """
    def create_jpg_from_fits(self, fits_filepath, outdir):
        '''
        Overriding instrument default function
        Tile images horizontally in order from left to right.
        Use DETSEC keyword to figure out data order/position
        '''

        # Only create jpeg for level 1 processing
        if self.level != 1:
            return
        
        #open
        hdus = fits.open(fits_filepath, ignore_missing_end=True)

        #needed hdr vals
        hdr0 = hdus[0].header

        if hdr0['KOAID'].startswith('DF'):
            super().create_jpg_from_fits(fits_filepath, outdir)
            return
            
        binning  = hdr0['BINNING'].split(',')
        precol   = int(hdr0['PRECOL'])   // int(binning[0])
        postpix  = int(hdr0['POSTPIX'])  // int(binning[0])
        preline  = int(hdr0['PRELINE'])  // int(binning[1])
        postline = int(hdr0['POSTLINE']) // int(binning[1])

        #get extension order (uses DETSEC keyword)
        ext_order = Deimos.get_ext_data_order(hdus)
        assert ext_order, "ERROR: Could not determine extended data order"

        #loop thru extended headers in order, create png and add to list in order
        interval = ZScaleInterval()
        vmin = None
        vmax = None
#        alldata = None
        # DEIMOS has 2 rows of 4 CCDs each
        alldata = [[], []]
        for row, extData in enumerate(ext_order):
            if len(extData) == 0: continue
            for i, ext in enumerate(extData):
                data = hdus[ext].data
                hdr  = hdus[ext].header
                if 'ndarray' not in str(type(data)): continue

                #calc bias array from postpix area
                sh = data.shape
                x1 = 0
                x2 = sh[0]
                y1 = sh[1] - postpix + 1
                y2 = sh[1] - 1
                bias = np.median(data[x1:x2, y1:y2], axis=1)
                bias = np.array(bias, dtype=np.int64)

                #subtract bias
                data = data - bias[:,None]

                #get min max of each ext (not including pre/post pixels)
                #NOTE: using sample box that is 90% of full area
                #todo: should we take an average min/max of each ext for balancing?
                sh = data.shape
                x1 = int(preline          + (sh[0] * 0.10))
                x2 = int(sh[0] - postline - (sh[0] * 0.10))
                y1 = int(precol           + (sh[1] * 0.10))
                y2 = int(sh[1] - postpix  - (sh[1] * 0.10))
                tmp_vmin, tmp_vmax = interval.get_limits(data[x1:x2, y1:y2])
                if vmin is None or tmp_vmin < vmin: vmin = tmp_vmin
                if vmax is None or tmp_vmax > vmax: vmax = tmp_vmax
                if vmin < 0: vmin = 0

                #remove pre/post pix columns
                data = data[:,precol:data.shape[1]-postpix]

                #flip data left/right
                #NOTE: This should come after removing pre/post pixels
                ds = Deimos.get_detsec_data(hdr['DETSEC'])
                if ds and ds[0] > ds[1]:
                    data = np.fliplr(data)
                if ds and ds[2] > ds[3]:
                    data = np.flipud(data)

                #concatenate horizontally
                if i == 0:
                    alldata[row] = data
                else:
                    alldata[row] = np.append(alldata[row], data, axis=1)

        # If alldata has 2 rows, then vertically stack them
        # else take the one row

        s0 = len(alldata[0])
        s1 = len(alldata[1])

        if s0 > 0 and s1 > 0:
            alldata = np.concatenate((alldata[0], alldata[1]), axis=0)
            # Need to rotate final stitched image
#            alldata = ndimage.rotate(alldata, -90, axes=(0, 1))
            alldata = np.rot90(alldata, 1, axes=(1, 0))
        elif s0 > 0:
            alldata = alldata[0]
        elif s1 > 0:
            alldata = alldata[1]

        #filepath vars
        basename = os.path.basename(fits_filepath).replace('.fits', '')
        out_filepath = f'{outdir}/{basename}.jpg'

        #bring in min/max by 2% to help ignore large areas of black or overexposed spots
        #todo: this does not achieve what we want
        # minmax_adjust = 0.02
        # vmin += int((vmax - vmin) * minmax_adjust)
        # vmax -= int((vmax - vmin) * minmax_adjust)

        #normalize, stretch and create jpg
        norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=AsinhStretch())
        dpi = 100
        width_inches  = alldata.shape[1] / dpi
        height_inches = alldata.shape[0] / dpi
        fig = plt.figure(figsize=(width_inches, height_inches), frameon=False, dpi=dpi)
        ax = fig.add_axes([0, 0, 1, 1]) #this forces no border padding; bbox_inches='tight' doesn't really work
        plt.axis('off')
        plt.imshow(alldata, cmap='gray', origin='lower', norm=norm)
        # DEIMOS jpegs are large, let's reduce the size using dpi (default is 100)
        # note: savefig's `quality` kwarg only exists in older matplotlib releases
        plt.savefig(out_filepath, quality=92, dpi=50)
        plt.close()
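The per-extension bias removal above medians the overscan (postpix) columns and subtracts the result row by row. A stripped-down sketch of that step, with invented array sizes:

    import numpy as np
    data = np.random.poisson(1000, (64, 80)).astype(np.int64)
    postpix = 8                                   # overscan columns at the right edge
    bias = np.median(data[:, -postpix:], axis=1)  # one bias value per row
    data = data - bias.astype(np.int64)[:, None]  # broadcast-subtract along rows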
Beispiel #28
0
def create_thumbnails(canvas, fig_photo_objects, id_value, id_value_index,
                      stretch):
    #global ID_values
    global RA_values
    global DEC_values
    global ra_dec_size_value
    global image_all
    global image_hdu_all
    global image_wcs_all
    global image_flux_value_err_cat
    global all_images_filter_name
    global number_images
    global SNR_values

    # Let's associate the selected object with its RA and DEC
    # Create the object thumbnails.
    #idx_cat = np.where(ID_values == id_value)[0]
    idx_cat = id_value_index
    objRA = RA_values[idx_cat]
    objDEC = DEC_values[idx_cat]

    cosdec_center = math.cos(objDEC * 3.141593 / 180.0)

    # Set the position of the object
    position = SkyCoord(str(objRA) + 'd ' + str(objDEC) + 'd', frame='fk5')
    size = u.Quantity((ra_dec_size_value, ra_dec_size_value), u.arcsec)

    fig_photo_objects = np.empty(0, dtype='object')
    start_time_full = time.time()
    for i in range(0, number_images):

        image = image_all[i].data
        image_hdu = image_hdu_all[i]
        image_wcs = image_wcs_all[i]

        if (all_images_filter_name[i] == 'HST_F814W'):
            image_wcs.sip = None

        if (image_flux_value_err_cat[idx_cat, i] > -9999):
            # Make the cutout

            start_time = time.time()
            image_cutout = Cutout2D(image, position, size, wcs=image_wcs)
            print "   Running Cutout2D took %s seconds" % (time.time() -
                                                           start_time)

            # Create the wcs axes
            plt.clf()
            start_time = time.time()
            fig = plt.figure(figsize=(1.5, 1.5))
            ax3 = fig.add_axes([0, 0, 1, 1], projection=image_cutout.wcs)
            ax3.text(0.51,
                     0.96,
                     all_images_filter_name[i].split('_')[1],
                     transform=ax3.transAxes,
                     fontsize=12,
                     fontweight='bold',
                     ha='center',
                     va='top',
                     color='black')
            ax3.text(0.5,
                     0.95,
                     all_images_filter_name[i].split('_')[1],
                     transform=ax3.transAxes,
                     fontsize=12,
                     fontweight='bold',
                     ha='center',
                     va='top',
                     color='white')
            if (SNR_values[idx_cat, i] > -100):
                ax3.text(0.96,
                         0.06,
                         'SNR = ' + str(round(SNR_values[idx_cat, i], 2)),
                         transform=ax3.transAxes,
                         fontsize=12,
                         fontweight='bold',
                         horizontalalignment='right',
                         color='black')
                ax3.text(0.95,
                         0.05,
                         'SNR = ' + str(round(SNR_values[idx_cat, i], 2)),
                         transform=ax3.transAxes,
                         fontsize=12,
                         fontweight='bold',
                         horizontalalignment='right',
                         color='white')
            else:
                ax3.text(0.96,
                         0.06,
                         'SNR < -100',
                         transform=ax3.transAxes,
                         fontsize=12,
                         fontweight='bold',
                         horizontalalignment='right',
                         color='black')
                ax3.text(0.95,
                         0.05,
                         'SNR < -100',
                         transform=ax3.transAxes,
                         fontsize=12,
                         fontweight='bold',
                         horizontalalignment='right',
                         color='white')
            print "    Plotting the text and SNR values took %s seconds" % (
                time.time() - start_time)

            # Set the color map
            plt.set_cmap('gray')

            start_time = time.time()
            indexerror = 0
            # Normalize the image using the min-max interval and a square root stretch
            thumbnail = image_cutout.data
            if (stretch == 'AsinhStretch'):
                try:
                    norm = ImageNormalize(thumbnail,
                                          interval=ZScaleInterval(),
                                          stretch=AsinhStretch())
                except IndexError:
                    indexerror = 1
                except UnboundLocalError:
                    indexerror = 1
            if (stretch == 'LogStretch'):
                try:
                    norm = ImageNormalize(thumbnail,
                                          interval=ZScaleInterval(),
                                          stretch=LogStretch(100))
                except IndexError:
                    indexerror = 1
                except UnboundLocalError:
                    indexerror = 1
            if (stretch == 'LinearStretch'):
                try:
                    norm = ImageNormalize(thumbnail,
                                          interval=ZScaleInterval(),
                                          stretch=LinearStretch())
                except IndexError:
                    indexerror = 1
                except UnboundLocalError:
                    indexerror = 1
            print "     ImageNormalize took %s seconds" % (time.time() -
                                                           start_time)

            start_time = time.time()
            if (indexerror == 0):
                ax3.imshow(thumbnail,
                           origin='lower',
                           aspect='equal',
                           norm=norm)
            else:
                ax3.imshow(thumbnail, origin='lower', aspect='equal')
            print "      Running imshow took %s seconds" % (time.time() -
                                                            start_time)

            start_time = time.time()
            if (i <= 5):
                fig_x, fig_y = 20 + (175 * i), 500
            if ((i > 5) & (i <= 11)):
                fig_x, fig_y = 20 + (175 * (i - 6)), 675
            if ((i > 11) & (i <= 17)):
                fig_x, fig_y = 20 + (175 * (i - 12)), 900

            # Keep this handle alive, or else figure will disappear
            fig_photo_objects = np.append(
                fig_photo_objects, draw_figure(canvas, fig,
                                               loc=(fig_x, fig_y)))
            plt.close('all')
            print "         Drawing the thumnails to the figures took %s seconds" % (
                time.time() - start_time)
    print "**** FULLY PLOTTING THUMBNAILS TOOK %s seconds" % (time.time() -
                                                              start_time_full)

    return fig_photo_objects
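The three stretch branches above could equally be driven by a lookup table; a hedged sketch with the same behavior, reusing the names and imports from the surrounding function:

    STRETCH_MAP = {'AsinhStretch': AsinhStretch(),
                   'LogStretch': LogStretch(100),
                   'LinearStretch': LinearStretch()}
    try:
        norm = ImageNormalize(thumbnail, interval=ZScaleInterval(),
                              stretch=STRETCH_MAP[stretch])
        indexerror = 0
    except (IndexError, UnboundLocalError):
        indexerror = 1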
def plot_fits(fname,
              ax=None,
              cmap='inferno',
              range=None,
              p0=[0, 0],
              pixel_size_x=None,
              pixel_size_y=None,
              dpc=None,
              vmin=None,
              vmax=None,
              rsquare=False,
              n_up=None,
              title=None,
              coronagraph_mask=None,
              fct='pcolormesh',
              beam=None,
              autoshift=False,
              PA=None,
              stretch=AsinhStretch(),
              image_fct=None):
    """
    fname : str
        path to file

    ax : None | axes
        where to plot the figure, create if None

    cmap : colormap
        which colormap to pass to pcolormesh

    range : None | float
        which range in mas (or au if dpc given) to plot around the center

    p0 : list of two floats
        where the center is located in mas

    pixel_size : float | None
        size of a pixel in mas
        set to None to try and read it from the fits file

    dpc : float
        distance in parsec

    vmin, vmax : None | float
        which lower and upper bound to use
        will be figured out automatically if `None`

    rsquare : bool
        if true, multiply the intensity with r**2

    n_up : None | int
        if int: upscale the image to that scale

    title : str
        title to plot in top left corner

    coronagraph_mask : None | float
        size of a central circle to, e.g., cover the coronagraph region,
        in mas (or au if dpc given)

    beam : list
        beam size (FWHM) for convolution in mas

    fct : str
        which bound method of the axes object to use for plotting.
        For transparent pdfs, it's for example better to use the slower pcolor
        while pcolormesh is much faster.

    autoshift : bool
        to put the brightest pixel in the center

    PA : None | float
        if float: rotate by this amount

    stretch : astropy.visualization.BaseStretch instance
        use LinearStretch for linear scaling, default for the
        DSHARP images is AsinhStretch

    image_fct : None | callable
        pass a function with signature fct(x, y, image) that will be called
        to process the image.
    """
    from scipy.ndimage import rotate
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    hdulist = fits.open(fname)
    header = hdulist[0].header
    Snu = np.squeeze(hdulist[0].data)
    Snu[np.isnan(Snu)] = 0.0

    if pixel_size_x is None:
        if 'CDELT1' in header:
            pixel_size_x = header['CDELT1'] * 3600. * 1e3
        else:
            pixel_size_x = 1
    if pixel_size_y is None:
        if 'CDELT2' in header:
            pixel_size_y = header['CDELT2'] * 3600. * 1e3
        else:
            pixel_size_y = 1

    if PA is not None:
        Snu = rotate(Snu, PA, reshape=False, mode='constant', cval=0.0)

    x = np.arange(Snu.shape[0], dtype=float) * abs(pixel_size_x)  # in mas
    y = np.arange(Snu.shape[1], dtype=float) * abs(pixel_size_y)  # in mas

    if dpc is not None:
        x *= dpc * 1e-3
        y *= dpc * 1e-3
        if beam is not None:
            beam = np.array(beam) * dpc * 1e-3

    x -= x[-1] / 2.0
    y -= y[-1] / 2.0
    x -= p0[0]
    y -= p0[1]

    if autoshift:
        cy, cx = np.unravel_index(Snu.argmax(), Snu.shape)
        x -= x[cx]
        y -= y[cy]

    if range is not None:
        ix0 = np.abs(x + range).argmin()
        ix1 = np.abs(x - range).argmin()
        iy0 = np.abs(y + range).argmin()
        iy1 = np.abs(y - range).argmin()

        x = x[ix0:ix1 + 1]
        y = y[iy0:iy1 + 1]
        Snu = Snu[iy0:iy1 + 1, ix0:ix1 + 1]

    if n_up is not None:
        print(f'scaling from {Snu.shape} to ({n_up},{n_up})')

        f = interp2d(x, y, Snu)

        x = np.linspace(x[0], x[-1], n_up)
        y = np.linspace(y[0], y[-1], n_up)
        Snu = np.maximum(0.0, f(x, y))

    std = Snu[:20, :20].std()
    if vmin is None:
        vmin = 1.5 * std
    if vmax is None:
        # vmax = 100 * std
        vmax = 0.75 * Snu.max()

    print('{}: vmin = {:.2g}, vmax = {:.2g}'.format(title, vmin, vmax))

    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)

    if beam is not None:
        print('beam = {}'.format(beam))
        sigma = beam / (2 * np.sqrt(2 * np.log(2)))
        Snu = ndimage.gaussian_filter(Snu, sigma)

    if image_fct is not None:
        Snu = image_fct(x, y, Snu)

    getattr(ax, fct)(
        -x,
        y,
        Snu,
        cmap=cmap,
        norm=norm,
        rasterized=True,
        # edgecolor=(1.0, 1.0, 1.0, 0.3), linewidth=0.0015625
    )

    ax.set_aspect('equal')
    if dpc is None:
        ax.set_xlabel(r'$\Delta$RA [mas]')
        ax.set_ylabel(r'$\Delta$DEC [mas]')
    else:
        ax.set_xlabel(r'$\Delta$RA [au]')
        ax.set_ylabel(r'$\Delta$DEC [au]')

    if range is not None:
        ax.set_xlim([range, -range])
        ax.set_ylim([-range, range])
    else:
        ax.set_xlim(ax.get_xlim()[::-1])

    if title is not None:
        ax.text(0.05,
                0.95,
                title,
                color='w',
                transform=ax.transAxes,
                verticalalignment='top')

    if coronagraph_mask is not None:
        ax.add_artist(plt.Circle((0, 0), radius=coronagraph_mask, color='0.5'))

    ax.set_facecolor('k')

    return fig, ax
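A minimal usage sketch for the function above; since its def line is cut off here, the name plot_image, the file name, and the parameter values are placeholders:

    # hypothetical call, assuming the function above is named `plot_image`
    fig, ax = plot_image('disk_continuum.fits', dpc=140, range=100,
                         beam=[40, 40], title='example disk', fct='pcolormesh')
    fig.savefig('disk_continuum.pdf', dpi=300)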
Example #30
0
def data_to_pitch(data_array,
                  pitch_range=[100, 10000],
                  center_pitch=440,
                  zero_point="median",
                  stretch='linear',
                  minmax_percent=None,
                  minmax_value=None,
                  invert=False):
    """
    Map data array to audible pitches in the given range, and apply stretch and scaling
    as required.

    Parameters
    ----------
    data_array : array-like
        Data to map to pitch values. Individual data values should be floats.
    pitch_range : array
        Optional, default [100,10000]. Range of acceptable pitches in Hz. 
    center_pitch : float
        Optional, default 440. The pitch in Hz to which the zero point of the
        data will be mapped.
    zero_point : str or float
        Optional, default "median". The data value that will be mapped to the center
        pitch. Options are mean, median, or a specified data value (float).
    stretch : str
        Optional, default 'linear'. The stretch to apply to the data array.
        Valid values are: asinh, sinh, sqrt, log, linear
    minmax_percent : array
        Optional. Interval based on keeping a specified fraction of data values (can be asymmetric)
        when scaling the data. The format is [lower percentile, upper percentile], where data
        values below the lower percentile and above the upper percentile are clipped.
        Only one of minmax_percent and minmax_value should be specified.
    minmax_value : array
        Optional. Interval based on user-specified data values when scaling the data array.
        The format is [min value, max value], where data values below the min value and above
        the max value are clipped.
        Only one of minmax_percent and minmax_value should be specified.
    invert : bool
        Optional, default False.  If True the pitch array is inverted (low pitches become high 
        and vice versa).

    Returns
    -------
    response : array
        The data array mapped to pitch values (in Hz) within the given pitch range.
    """

    # Parsing the zero point
    if zero_point in ("med", "median"):
        zero_point = np.median(data_array)
    elif zero_point in ("ave", "mean", "average"):
        zero_point = np.mean(data_array)

    if (data_array == zero_point).all():
        # all values are identical, so everything maps to the center pitch
        return np.full(len(data_array), center_pitch)

    # Normalizing the data_array and adding the zero point (so it can go through the same transform)
    data_array = np.append(np.array(data_array), zero_point)

    # Setting up the transform with the stretch
    if stretch == 'asinh':
        transform = AsinhStretch()
    elif stretch == 'sinh':
        transform = SinhStretch()
    elif stretch == 'sqrt':
        transform = SqrtStretch()
    elif stretch == 'log':
        transform = LogStretch()
    elif stretch == 'linear':
        transform = LinearStretch()
    else:
        raise InvalidInputError("Stretch {} is not supported!".format(stretch))

    # Adding the scaling to the transform
    if minmax_percent is not None:
        transform += AsymmetricPercentileInterval(*minmax_percent)

        if minmax_value is not None:
            warnings.warn(
                "Both minmax_percent and minmax_value are set, minmax_value will be ignored.",
                InputWarning)
    elif minmax_value is not None:
        transform += ManualInterval(*minmax_value)
    else:  # Default, scale the entire image range to [0,1]
        transform += MinMaxInterval()
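    # note: astropy transforms compose such that `stretch + interval` applies
    # the interval first (rescaling the data to [0, 1]) and the stretch second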

    # Performing the transform and then putting it into the pitch range
    pitch_array = transform(data_array)

    if invert:
        pitch_array = 1 - pitch_array

    zero_point = pitch_array[-1]
    pitch_array = pitch_array[:-1]

    # Rescale the normalized data to pitches. The zero point maps exactly to
    # center_pitch; the anchor (bottom or top of pitch_range) is chosen so
    # that the mapped values do not exceed the range.
    if ((1 / zero_point) * (center_pitch - pitch_range[0]) + pitch_range[0]) <= pitch_range[1]:
        # anchor at the low end: 0 -> pitch_range[0], zero_point -> center_pitch
        pitch_array = (pitch_array / zero_point) * (center_pitch - pitch_range[0]) + pitch_range[0]
    else:
        # anchor at the high end: zero_point -> center_pitch, 1 -> pitch_range[1]
        pitch_array = ((pitch_array - zero_point) / (1 - zero_point)) * (pitch_range[1] - center_pitch) + center_pitch

    return pitch_array
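A quick usage sketch of data_to_pitch as defined above (the data values are arbitrary):

    import numpy as np

    data = np.array([1.0, 3.0, 2.0, 10.0, 5.0])
    pitches = data_to_pitch(data, pitch_range=[200, 2000], center_pitch=440,
                            zero_point='median', stretch='log')
    # the median (3.0) maps exactly to 440 Hz; with these values the whole
    # array stays within [200, 2000] Hz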