Example No. 1
def main():

    filename = get_stack_filename()

    ad = astrodata.open(filename)

    data = ad[0].data
    mask = ad[0].mask
    header = ad[0].hdr

    masked_data = np.ma.masked_where(mask, data, copy=True)

    palette = copy(plt.cm.viridis)
    palette.set_bad('gray')

    norm_factor = visualization.ImageNormalize(
        masked_data,
        stretch=visualization.LinearStretch(),
        interval=visualization.ZScaleInterval(),
    )

    fig, ax = plt.subplots(subplot_kw={'projection': wcs.WCS(header)})

    ax.imshow(masked_data,
              cmap=palette,
              vmin=norm_factor.vmin,
              vmax=norm_factor.vmax)

    ax.set_title(os.path.basename(filename))
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    fig.savefig(filename.replace('.fits', '.png'))
    plt.show()
Example No. 2
def diagnostic_plots(output_dir, hdu, image_wcs, detected_sources,
                     catalog_sources_xy, matched_stars, offset_x, offset_y):
    """Function to output plots used to assess and debug the astrometry
    performed on the reference image"""

    norm = visualization.ImageNormalize(hdu[0].data, \
                        interval=visualization.ZScaleInterval())

    fig = plt.figure(1)
    plt.imshow(hdu[0].data, origin='lower', cmap=plt.cm.viridis, norm=norm)
    plt.plot(detected_sources[:,1],detected_sources[:,2],'o',markersize=2,\
             markeredgewidth=1,markeredgecolor='b',markerfacecolor='None')
    plt.plot(catalog_sources_xy[:, 0], catalog_sources_xy[:, 1], 'r+')
    plt.xlabel('X pixel')
    plt.ylabel('Y pixel')
    plt.savefig(path.join(output_dir, 'reference_detected_sources_pixels.png'))
    plt.close(1)

    fig = plt.figure(1)
    fig.add_subplot(111, projection=image_wcs)
    plt.imshow(hdu[0].data, origin='lower', cmap=plt.cm.viridis, norm=norm)
    plt.plot(detected_sources[:,1],detected_sources[:,2],'o',markersize=2,\
             markeredgewidth=1,markeredgecolor='b',markerfacecolor='None')
    plt.plot(catalog_sources_xy[:, 0], catalog_sources_xy[:, 1], 'r+')
    plt.xlabel('RA [J2000]')
    plt.ylabel('Dec [J2000]')
    plt.savefig(path.join(output_dir, 'reference_detected_sources_world.png'))
    plt.close(1)

    fig = plt.figure(2)
    plt.subplot(211)
    plt.subplots_adjust(left=0.125,
                        bottom=0.1,
                        right=0.9,
                        top=0.95,
                        wspace=0.1,
                        hspace=0.3)
    plt.hist((matched_stars[:, 5] - matched_stars[:, 2]), 50)
    (xmin, xmax, ymin, ymax) = plt.axis()
    plt.plot([offset_x, offset_x], [ymin, ymax], 'r-')
    plt.xlabel('(Detected-catalog) X pixel')
    plt.ylabel('Frequency')
    plt.subplot(212)
    plt.hist((matched_stars[:, 4] - matched_stars[:, 1]), 50)
    (xmin, xmax, ymin, ymax) = plt.axis()
    plt.plot([offset_y, offset_y], [ymin, ymax], 'r-')
    plt.xlabel('(Detected-catalog) Y pixel')
    plt.ylabel('Frequency')
    plt.savefig(path.join(output_dir, 'astrometry_separations.png'))
    plt.close(2)
Example No. 3
def main():
    args = _parse_args()
    filename = args.filename

    ad = astrodata.open(filename)

    data = ad[0].data
    mask = ad[0].mask
    header = ad[0].hdr

    if args.mask:
        masked_data = np.ma.masked_where(mask, data, copy=True)
    else:
        masked_data = data

    palette = copy(plt.cm.viridis)
    palette.set_bad('Gainsboro')

    norm_factor = visualization.ImageNormalize(
        masked_data,
        stretch=visualization.LinearStretch(),
        interval=visualization.ZScaleInterval(),
    )

    fig = plt.figure(num=filename)
    ax = fig.subplots(subplot_kw={"projection": wcs.WCS(header)})

    print(norm_factor.vmin)
    print(norm_factor.vmax)
    ax.imshow(
        masked_data,
        cmap=palette,
        #vmin=norm_factor.vmin,
        #vmax=norm_factor.vmax,
        vmin=750.,
        vmax=900.,
        origin='lower')

    ax.set_title(os.path.basename(filename))

    ax.coords[0].set_axislabel('Right Ascension')
    ax.coords[0].set_ticklabel(fontsize='small')

    ax.coords[1].set_axislabel('Declination')
    ax.coords[1].set_ticklabel(rotation='vertical', fontsize='small')

    fig.tight_layout(rect=[0.05, 0, 1, 1])
    fig.savefig(os.path.basename(filename.replace('.fits', '.png')))
    plt.show()
Example No. 4
def build_psf_mask(setup, psf_size, diagnostics=False):
    """Function to construct a mask for the PSF of a single star, which
    is a 2D image array with 1.0 at all pixel locations within the PSF and 
    zero everywhere outside it."""

    half_psf = int(psf_size)
    half_psf2 = half_psf * half_psf

    pxmax = 2 * half_psf + 1
    pymax = pxmax

    psf_mask = np.ones([pymax, pxmax])

    pxmin = -half_psf
    pxmax = half_psf + 1
    pymin = -half_psf
    pymax = half_psf + 1

    for dx in range(pxmin, pxmax, 1):

        dx2 = dx * dx

        for dy in range(pymin, pymax, 1):

            if (dx2 + dy * dy) > half_psf2:
                psf_mask[dx + half_psf, dy + half_psf] = 0.0

    if diagnostics:
        fig = plt.figure(1)

        norm = visualization.ImageNormalize(psf_mask,
                                            interval=visualization.ZScaleInterval())

        plt.imshow(psf_mask, origin='lower', cmap=plt.cm.viridis, norm=norm)
        plt.xlabel('X pixel')
        plt.ylabel('Y pixel')
        plt.savefig(path.join(setup.red_dir, 'psf_mask.png'))
        plt.close(1)

    return psf_mask
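
The double loop above marks a circular footprint of radius psf_size. For reference, a minimal vectorized sketch of the same footprint (not part of the example above; it assumes only NumPy and uses a hypothetical helper name):

import numpy as np

def build_psf_mask_vectorized(psf_size):
    # Hypothetical equivalent of the loop above: 1.0 inside the circular
    # PSF footprint, 0.0 outside.
    half_psf = int(psf_size)
    dy, dx = np.ogrid[-half_psf:half_psf + 1, -half_psf:half_psf + 1]
    return (dx * dx + dy * dy <= half_psf * half_psf).astype(float)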
Example No. 5
def main():

    filename = get_stack_filename()

    ad = astrodata.open(filename)

    fig = plt.figure(num=filename, figsize=(7, 4.5))
    fig.suptitle(os.path.basename(filename), y=0.97)

    axs = fig.subplots(1, len(ad), sharey=True)

    palette = copy(plt.cm.viridis)
    palette.set_bad("Gainsboro", 1.0)

    norm = visualization.ImageNormalize(
        np.dstack([ext.data for ext in ad]),
        stretch=visualization.LinearStretch(),
        interval=visualization.ZScaleInterval()
    )

    print(norm.vmin)
    print(norm.vmax)
    for i in range(len(ad)):

        axs[i].imshow(
            # np.ma.masked_where(ad[i].mask > 0, ad[i].data),
            ad[i].data,
            #norm=colors.Normalize(vmin=norm.vmin, vmax=norm.vmax),
            norm=colors.Normalize(vmin=750, vmax=900),
            origin="lower",
            cmap=palette,
        )

        axs[i].set_xlabel('d{:02d}'.format(i+1))
        axs[i].set_xticks([])

    axs[i].set_yticks([])

    fig.tight_layout(rect=[0, 0, 1, 1], w_pad=0.05)

    fig.savefig(os.path.basename(filename.replace('.fits', '.png')))
    plt.show()
Example No. 6
def get_plot_norm(data, vmin=None, vmax=None, zscale=False, scale='linear'):
    from astropy import visualization as viz
    from astropy.visualization.mpl_normalize import ImageNormalize

    if zscale:
        interval = viz.ZScaleInterval()
        vmin, vmax = interval.get_limits(data.filled(0))

    if scale == 'linear':
        stretch = viz.LinearStretch
    elif scale == 'log':
        stretch = viz.LogStretch
    elif scale in ('asinh', 'arcsinh'):
        stretch = viz.AsinhStretch
    elif scale == 'sqrt':
        stretch = viz.SqrtStretch
    else:
        raise ValueError('Unknown scale: {}'.format(scale))

    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch(), clip=False)

    return norm
Example No. 7
def get_plot_norm(data, vmin=None, vmax=None, zscale=False, scale='linear'):
    from astropy import visualization as viz
    from astropy.visualization.mpl_normalize import ImageNormalize

    # Choose vmin and vmax automatically?
    if zscale:
        interval = viz.ZScaleInterval()
        if data.dtype == np.float64:
            try:
                vmin, vmax = interval.get_limits(data.filled(np.nan))
            except Exception:
                # catch failure on all NaN
                if np.all(np.isnan(data.filled(np.nan))):
                    vmin, vmax = (np.nan, np.nan)
                else:
                    raise
        else:
            vmin, vmax = interval.get_limits(data.filled(0))

    # How are values between vmin and vmax mapped to corresponding
    # positions along the colorbar?
    if scale == 'linear':
        stretch = viz.LinearStretch
    elif scale == 'log':
        stretch = viz.LogStretch
    elif scale in ('asinh', 'arcsinh'):
        stretch = viz.AsinhStretch
    elif scale == 'sqrt':
        stretch = viz.SqrtStretch
    else:
        raise ValueError('Unknown scale: {}'.format(scale))

    # Create an object that will be used to map pixel values
    # in the range vmin..vmax to normalized colormap indexes.
    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch(), clip=False)

    return norm
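
A minimal usage sketch for the function above (the synthetic image and the plotting calls are illustrative, not part of the original snippet):

import numpy as np
import matplotlib.pyplot as plt

data = np.ma.masked_invalid(np.random.normal(1000.0, 50.0, (64, 64)))  # placeholder image
norm = get_plot_norm(data, zscale=True, scale='asinh')

plt.imshow(data.filled(np.nan), origin='lower', cmap='viridis', norm=norm)
plt.colorbar()
plt.show()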
Example No. 8
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from astropy.io import fits
from PIL import Image
import astropy.visualization as visualization
plt.style.use('mystyle-2.mplstyle')
zscale = visualization.ZScaleInterval()

# load no-star image
maskpath = '../Images/star_mask.fits'  # image with central star and blooming removed
hdulist = fits.open(maskpath)
starmask = hdulist[0].data
hdulist.close()

imshow(starmask)

#%%
# read header
filename = "../Images/A1_mosaic.fits"  # with frame but no star
hdulist = fits.open(filename)
original = hdulist[0].data
header = hdulist[0].header
hdulist.close()

#%%
# load original image
filename = "A1_mosaic_nostar.fits"  # with frame but no star
hdulist = fits.open(filename)
image = hdulist[0].data
hdulist.close()
Example No. 9
def mask_stars(setup,
               ref_image,
               ref_star_catalog,
               psf_mask,
               diagnostics=False):
    """Function to create a mask for an image, returning a 2D image that has 
    a value of 0.0 in every pixel that falls within the PSF of a star, and
    a value of 1.0 everywhere else."""

    psf_size = psf_mask.shape[0]

    half_psf = int(psf_size / 2.0)

    pxmin = 0
    pxmax = psf_mask.shape[0]
    pymin = 0
    pymax = psf_mask.shape[1]

    star_mask = np.ones(ref_image.shape)

    for j in range(0, len(ref_star_catalog), 1):

        xstar = int(ref_star_catalog[j, 1])
        ystar = int(ref_star_catalog[j, 2])

        xmin = xstar - half_psf
        xmax = xstar + half_psf + 1
        ymin = ystar - half_psf
        ymax = ystar + half_psf + 1

        px1 = pxmin
        px2 = pxmax + 1
        py1 = pymin
        py2 = pymax + 1

        if xmin < 0:

            px1 = abs(xmin)
            xmin = 0

        if ymin < 0:

            py1 = abs(ymin)
            ymin = 0

        if xmax > ref_image.shape[0]:

            px2 = px2 - (xmax - ref_image.shape[0] + 1)
            xmax = ref_image.shape[0] + 1

        if ymax > ref_image.shape[1]:

            py2 = py2 - (ymax - ref_image.shape[1] + 1)
            ymax = ref_image.shape[1] + 1

        star_mask[ymin:ymax,xmin:xmax] = star_mask[ymin:ymax,xmin:xmax] + \
                                            psf_mask[py1:py2,px1:px2]

    idx_mask = np.where(star_mask > 1.0)
    star_mask[idx_mask] = np.nan

    idx = np.isnan(star_mask)
    idx = np.where(idx == False)

    star_mask[idx] = 0.0
    star_mask[idx_mask] = 1.0

    masked_ref_image = np.ma.masked_array(ref_image, mask=star_mask)

    if diagnostics:
        fig = plt.figure(2)
        norm = visualization.ImageNormalize(star_mask,
                                            interval=visualization.ZScaleInterval())

        plt.imshow(star_mask, origin='lower', cmap=plt.cm.viridis, norm=norm)
        plt.plot(ref_star_catalog[:, 1], ref_star_catalog[:, 2], 'r+')
        plt.xlabel('X pixel')
        plt.ylabel('Y pixel')
        plt.colorbar()
        plt.savefig(path.join(setup.red_dir, 'ref_star_mask.png'))
        plt.close(2)

        fig = plt.figure(3)
        norm = visualization.ImageNormalize(masked_ref_image,
                                            interval=visualization.ZScaleInterval())

        plt.imshow(masked_ref_image,
                   origin='lower',
                   cmap=plt.cm.viridis,
                   norm=norm)
        plt.plot(ref_star_catalog[:, 1], ref_star_catalog[:, 2], 'r+')
        plt.xlabel('X pixel')
        plt.ylabel('Y pixel')
        plt.colorbar()
        plt.savefig(path.join(setup.red_dir, 'masked_ref_image.png'))
        plt.close(3)

    return masked_ref_image
Example No. 10
 def set_normalization(self,
                       stretch=None,
                       interval=None,
                       stretchkwargs={},
                       intervalkwargs={},
                       perm_linear=None):
     if stretch is None:
         if self.stretch is None:
             stretch = 'linear'
         else:
             stretch = self.stretch
     if isinstance(stretch, str):
         print(stretch,
               ' '.join([f'{k}={v}' for k, v in stretchkwargs.items()]))
         if self.data is None:  #can not calculate objects yet
             self.stretch_kwargs = stretchkwargs
         else:
             kwargs = self.prepare_kwargs(
                 self.stretch_kws_defaults[stretch], self.stretch_kwargs,
                 stretchkwargs)
             if perm_linear is not None:
                 perm_linear_kwargs = self.prepare_kwargs(
                     self.stretch_kws_defaults['linear'], perm_linear)
                 print(
                     'linear', ' '.join([
                         f'{k}={v}' for k, v in perm_linear_kwargs.items()
                     ]))
                 if stretch == 'asinh':  # arg: a=0.1
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.AsinhStretch(**kwargs))
                 elif stretch == 'contrastbias':  # args: contrast, bias
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.ContrastBiasStretch(**kwargs))
                 elif stretch == 'histogram':
                     stretch = vis.CompositeStretch(
                         vis.HistEqStretch(self.data, **kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'log':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LogStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'powerdist':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.PowerDistStretch(**kwargs))
                 elif stretch == 'power':  # args: a
                     stretch = vis.CompositeStretch(
                         vis.PowerStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'sinh':  # args: a=0.33
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SinhStretch(**kwargs))
                 elif stretch == 'sqrt':
                     stretch = vis.CompositeStretch(
                         vis.SqrtStretch(),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'square':
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SquaredStretch())
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
             else:
                 if stretch == 'linear':  # args: slope=1, intercept=0
                     stretch = vis.LinearStretch(**kwargs)
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
     self.stretch = stretch
     if interval is None:
         if self.interval is None:
             interval = 'zscale'
         else:
             interval = self.interval
     if isinstance(interval, str):
         print(interval,
               ' '.join([f'{k}={v}' for k, v in intervalkwargs.items()]))
         kwargs = self.prepare_kwargs(self.interval_kws_defaults[interval],
                                      self.interval_kwargs, intervalkwargs)
         if self.data is None:
             self.interval_kwargs = intervalkwargs
         else:
             if interval == 'minmax':
                 interval = vis.MinMaxInterval()
             elif interval == 'manual':  # args: vmin, vmax
                 interval = vis.ManualInterval(**kwargs)
             elif interval == 'percentile':  # args: percentile, n_samples
                 interval = vis.PercentileInterval(**kwargs)
             elif interval == 'asymetric':  # args: lower_percentile, upper_percentile, n_samples
                 interval = vis.AsymmetricPercentileInterval(**kwargs)
             elif interval == 'zscale':  # args: nsamples=1000, contrast=0.25, max_reject=0.5, min_npixels=5, krej=2.5, max_iterations=5
                 interval = vis.ZScaleInterval(**kwargs)
             else:
                 raise ValueError('Unknown interval:' + interval)
     self.interval = interval
     if self.img is not None:
         self.img.set_norm(
             vis.ImageNormalize(self.data,
                                interval=self.interval,
                                stretch=self.stretch,
                                clip=True))
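
For context, a standalone sketch of the kind of composite normalization this method builds when perm_linear is given; the array and the parameter values below are illustrative only:

import numpy as np
from astropy import visualization as vis

data = np.random.normal(1000.0, 50.0, (128, 128))  # placeholder image
# asinh stretch composed with a permanent linear stretch, zscale interval
stretch = vis.CompositeStretch(vis.LinearStretch(slope=1, intercept=0),
                               vis.AsinhStretch(a=0.1))
norm = vis.ImageNormalize(data,
                          interval=vis.ZScaleInterval(contrast=0.25),
                          stretch=stretch,
                          clip=True)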
Example No. 11
std = np.std(fwhm)
med = np.median(fwhm)
low = 2
high = 2
clip_mask = ((fwhm < (med - (std * low))) | (fwhm > (med + (std * high))))
data = data[~clip_mask]
print(data)

# extract the positions of all sources
positions = np.transpose((data["X_IMAGE"], data["Y_IMAGE"]))

# SExtractor coordinates are 1-indexed, so shift to 0-indexed array positions
positions = positions - 1

# define image normalisation
norm = aviz.ImageNormalize(img, interval=aviz.ZScaleInterval())

fig = pyplot.figure()

if (hist_dist == 1) and ("star" in fi_name):

    gs = gridspec.GridSpec(2, 1)
    ax2 = pyplot.subplot(gs[1, 0])

    #plot the distribution of FWHM
    ax2.hist(data["FWHM_IMAGE"] * pix_scale, bins="auto")
    ax2.axvline(np.median(data["FWHM_IMAGE"] * pix_scale),
                label="median = {:.3f}".format(
                    np.median(data["FWHM_IMAGE"] * pix_scale)),
                c='r')
    ax2.axvline(np.mean(data["FWHM_IMAGE"] * pix_scale),
Example No. 12
def test_find_psf_companion_stars():
    """Function to test the identification of stars that neighbour a PSF star 
    from the reference catalogue."""

    setup = pipeline_setup.pipeline_setup({'red_dir': TEST_DIR})

    log = logs.start_stage_log(cwd, 'test_find_psf_companions')

    log.info(setup.summary())

    reduction_metadata = metadata.MetaData()
    reduction_metadata.load_a_layer_from_file(setup.red_dir,
                                              'pyDANDIA_metadata.fits',
                                              'reduction_parameters')

    star_catalog_file = os.path.join(TEST_DATA, 'star_catalog.fits')

    ref_star_catalog = catalog_utils.read_ref_star_catalog_file(
        star_catalog_file)

    log.info('Read in catalog of ' + str(len(ref_star_catalog)) + ' stars')

    psf_idx = 18
    psf_x = 189.283172607
    psf_y = 9.99084472656
    psf_size = 8.0

    stamp_dims = (20, 20)

    comps_list = psf.find_psf_companion_stars(setup, psf_idx, psf_x, psf_y,
                                              psf_size, ref_star_catalog, log,
                                              stamp_dims)

    assert len(comps_list) > 0

    for l in comps_list:
        log.info(repr(l))

    image_file = os.path.join(TEST_DATA,
                              'lsc1m005-fl15-20170701-0144-e91_cropped.fits')

    image = fits.getdata(image_file)

    corners = psf.calc_stamp_corners(psf_x,
                                     psf_y,
                                     stamp_dims[1],
                                     stamp_dims[0],
                                     image.shape[1],
                                     image.shape[0],
                                     over_edge=True)

    stamp = image[corners[2]:corners[3], corners[0]:corners[1]]

    log.info('Extracting PSF stamp image')

    fig = plt.figure(1)
    norm = visualization.ImageNormalize(stamp,
                                        interval=visualization.ZScaleInterval())

    plt.imshow(stamp, origin='lower', cmap=plt.cm.viridis, norm=norm)

    x = []
    y = []
    for j in range(0, len(comps_list), 1):
        x.append(comps_list[j][1])
        y.append(comps_list[j][2])

    plt.plot(x, y, 'r+')
    plt.axis('equal')
    plt.xlabel('X pixel')
    plt.ylabel('Y pixel')
    plt.savefig(os.path.join(TEST_DATA, 'psf_companion_stars.png'))
    plt.close(1)

    logs.close_log(log)
Example No. 13
def make_fov_image(fov, pngfn=None, **kwargs):
    stretch = kwargs.get('stretch', 'linear')
    interval = kwargs.get('interval', 'zscale')
    imrange = kwargs.get('imrange')
    contrast = kwargs.get('contrast', 0.25)
    ccdplotorder = ['CCD2', 'CCD4', 'CCD1', 'CCD3']
    if interval == 'rms':
        try:
            losig, hisig = imrange
        except:
            losig, hisig = (2.5, 5.0)
    #
    cmap = kwargs.get('cmap', 'viridis')
    cmap = plt.get_cmap(cmap)
    cmap.set_bad('w', 1.0)
    w = 0.4575
    h = 0.455
    rc('text', usetex=False)
    fig = plt.figure(figsize=(6, 6.5))
    cax = fig.add_axes([0.1, 0.04, 0.8, 0.01])
    ims = [fov[ccd]['im'] for ccd in ccdplotorder]
    allpix = np.ma.array(ims).flatten()
    stretch = {
        'linear': vis.LinearStretch(),
        'histeq': vis.HistEqStretch(allpix),
        'asinh': vis.AsinhStretch(),
    }[stretch]
    if interval == 'zscale':
        iv = vis.ZScaleInterval(contrast=contrast)
        vmin, vmax = iv.get_limits(allpix)
    elif interval == 'rms':
        nsample = 1000 // nbin  # NOTE: nbin is assumed to be defined in the enclosing scope
        background = sigma_clip(allpix[::nsample], iters=3, sigma=2.2)
        m, s = background.mean(), background.std()
        vmin, vmax = m - losig * s, m + hisig * s
    elif interval == 'fixed':
        vmin, vmax = imrange
    else:
        raise ValueError
    norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch)
    for n, (im, ccd) in enumerate(zip(ims, ccdplotorder)):
        if im.ndim == 3:
            im = im.mean(axis=-1)
        x = fov[ccd]['x']
        y = fov[ccd]['y']
        i = n % 2
        j = n // 2
        pos = [0.0225 + i * w + i * 0.04, 0.05 + j * h + j * 0.005, w, h]
        ax = fig.add_axes(pos)
        _im = ax.imshow(im,
                        origin='lower',
                        extent=[x[0, 0], x[0, -1], y[0, 0], y[-1, 0]],
                        norm=norm,
                        cmap=cmap,
                        interpolation=kwargs.get('interpolation', 'nearest'))
        if fov['coordsys'] == 'sky':
            ax.set_xlim(x.max(), x.min())
        else:
            ax.set_xlim(x.min(), x.max())
        ax.set_ylim(y.min(), y.max())
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        if n == 0:
            cb = fig.colorbar(_im, cax, orientation='horizontal')
            cb.ax.tick_params(labelsize=9)
    tstr = fov.get('file', '') + ' ' + fov.get('objname', '')
    title = kwargs.get('title', tstr)
    title = title[-60:]
    fig.text(0.5, 0.99, title, ha='center', va='top', size=12)
    if pngfn is not None:
        plt.savefig(pngfn)
        plt.close(fig)
Example No. 14
def main():

    # filename = 'S20170505S0102_flatCorrected.fits'
    filename = get_filename()

    ad = astrodata.open(filename)
    print(ad.info())

    fig = plt.figure(num=filename, figsize=(8, 8))
    fig.suptitle('{}'.format(filename))

    palette = copy(plt.cm.viridis)
    palette.set_bad('w', 1.0)

    norm = visualization.ImageNormalize(
        np.dstack([ad[i].data for i in range(4)]),
        stretch=visualization.LinearStretch(),
        interval=visualization.ZScaleInterval())

    ax1 = fig.add_subplot(224)
    ax1.imshow(np.ma.masked_where(ad[0].mask > 0, ad[0].data),
               norm=colors.Normalize(vmin=norm.vmin, vmax=norm.vmax),
               origin='lower',
               cmap=palette)

    ax1.annotate('d1', (20, 20), color='white')
    ax1.set_xlabel('x [pixels]')
    ax1.set_ylabel('y [pixels]')

    ax2 = fig.add_subplot(223)
    ax2.imshow(np.ma.masked_where(ad[1].mask > 0, ad[1].data),
               norm=colors.Normalize(vmin=norm.vmin, vmax=norm.vmax),
               origin='lower',
               cmap=palette)

    ax2.annotate('d2', (20, 20), color='white')
    ax2.set_xlabel('x [pixels]')
    ax2.set_ylabel('y [pixels]')

    ax3 = fig.add_subplot(221)
    ax3.imshow(np.ma.masked_where(ad[2].mask > 0, ad[2].data),
               norm=colors.Normalize(vmin=norm.vmin, vmax=norm.vmax),
               origin='lower',
               cmap=palette)

    ax3.annotate('d3', (20, 20), color='white')
    ax3.set_xlabel('x [pixels]')
    ax3.set_ylabel('y [pixels]')

    ax4 = fig.add_subplot(222)
    ax4.imshow(np.ma.masked_where(ad[3].mask > 0, ad[3].data),
               norm=colors.Normalize(vmin=norm.vmin, vmax=norm.vmax),
               origin='lower',
               cmap=palette)

    ax4.annotate('d4', (20, 20), color='white')
    ax4.set_xlabel('x [pixels]')
    ax4.set_ylabel('y [pixels]')

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    plt.savefig(filename.replace('.fits', '.png'))
    plt.show()
Example No. 15
    def run(self, doqa=True, debug=False, show=False):
        """
        Main driver for tracing arc lines

        Code flow:

            #. Extract an arc spectrum down the center of each slit/order
            #. Loop on slits/orders
                #. Trace and fit the arc lines (This is done twice, once
                   with trace_crude as the tracing crutch, then again
                   with a PCA model fit as the crutch).
                #. Repeat trace.
                #.  2D Fit to the offset from slitcen
                #. Save

        Args:
            doqa (bool):
            debug (bool):
            show (bool):

        Returns:
            :class:`WaveTilts`:

        """
        # Extract the arc spectra for all slits
        self.arccen, self.arccen_bpm = self.extract_arcs()

        # TODO: Leave for now.  Used for debugging
#        self.par['rm_continuum'] = True
#        debug = True
#        show = True

        # Subtract arc continuum
        _mstilt = self.mstilt.image.copy()
        if self.par['rm_continuum']:
            continuum = self.model_arc_continuum(debug=debug)
            _mstilt -= continuum
            if debug:
                # TODO: Put this into a function
                vmin, vmax = visualization.ZScaleInterval().get_limits(_mstilt)
                w,h = plt.figaspect(1)
                fig = plt.figure(figsize=(3*w,h))
                ax = fig.add_axes([0.15/3, 0.1, 0.8/3, 0.8])
                ax.imshow(self.mstilt.image, origin='lower', interpolation='nearest',
                          aspect='auto', vmin=vmin, vmax=vmax)
                ax.set_title('MasterArc')
                ax = fig.add_axes([1.15/3, 0.1, 0.8/3, 0.8])
                ax.imshow(continuum, origin='lower', interpolation='nearest',
                          aspect='auto', vmin=vmin, vmax=vmax)
                ax.set_title('Continuum')
                ax = fig.add_axes([2.15/3, 0.1, 0.8/3, 0.8])
                ax.imshow(_mstilt, origin='lower', interpolation='nearest',
                          aspect='auto', vmin=vmin, vmax=vmax)
                ax.set_title('MasterArc - Continuum')
                plt.show()

        # Final tilts image
        self.final_tilts = np.zeros(self.shape_science,dtype=float)
        max_spat_dim = (np.asarray(self.par['spat_order']) + 1).max()
        max_spec_dim = (np.asarray(self.par['spec_order']) + 1).max()
        self.coeffs = np.zeros((max_spec_dim, max_spat_dim,self.slits.nslits))
        self.spat_order = np.zeros(self.slits.nslits, dtype=int)
        self.spec_order = np.zeros(self.slits.nslits, dtype=int)

        # TODO sort out show methods for debugging
        if show:
            viewer,ch = ginga.show_image(self.mstilt.image*(self.slitmask > -1),chname='tilts')

        # Loop on all slits
        for slit_idx, slit_spat in enumerate(self.slits.spat_id):
            if self.tilt_bpm[slit_idx]:
                continue
            #msgs.info('Computing tilts for slit {0}/{1}'.format(slit, self.slits.nslits-1))
            msgs.info('Computing tilts for slit {0}/{1}'.format(slit_idx, self.slits.nslits))
            # Identify lines for tracing tilts
            msgs.info('Finding lines for tilt analysis')
            self.lines_spec, self.lines_spat \
                    = self.find_lines(self.arccen[:,slit_idx], self.slitcen[:,slit_idx],
                                      slit_idx,
                                      bpm=self.arccen_bpm[:,slit_idx], debug=debug)

            if self.lines_spec is None:
                self.slits.mask[slit_idx] = self.slits.bitmask.turn_on(self.slits.mask[slit_idx], 'BADTILTCALIB')
                continue

            thismask = self.slitmask == slit_spat

            # Performs the initial tracing of the line centroids as a
            # function of spatial position resulting in 1D traces for
            # each line.
            msgs.info('Trace the tilts')
            self.trace_dict = self.trace_tilts(_mstilt, self.lines_spec, self.lines_spat,
                                               thismask, self.slitcen[:, slit_idx])

            # TODO: Show the traces before running the 2D fit

            if show:
                ginga.show_tilts(viewer, ch, self.trace_dict)

            self.spat_order[slit_idx] = self._parse_param(self.par, 'spat_order', slit_idx)
            self.spec_order[slit_idx] = self._parse_param(self.par, 'spec_order', slit_idx)
            # 2D model of the tilts, includes construction of QA
            # NOTE: This also fills in self.all_fit_dict and self.all_trace_dict
            coeff_out = self.fit_tilts(self.trace_dict, thismask, self.slitcen[:,slit_idx],
                                       self.spat_order[slit_idx], self.spec_order[slit_idx],
                                       slit_idx,
                                       doqa=doqa, show_QA=show, debug=show)
            self.coeffs[:self.spec_order[slit_idx]+1,:self.spat_order[slit_idx]+1,slit_idx] = coeff_out

            # TODO: Need a way to assess the success of fit_tilts and
            # flag the slit if it fails

            # Tilts are created with the size of the original slitmask,
            # which corresonds to the same binning as the science
            # images, trace images, and pixelflats etc.
            self.tilts = tracewave.fit2tilts(self.slitmask_science.shape, coeff_out,
                                             self.par['func2d'])
            # Save to final image
            thismask_science = self.slitmask_science == slit_spat
            self.final_tilts[thismask_science] = self.tilts[thismask_science]

        if debug:
            # TODO: Add this to the show method?
            vmin, vmax = visualization.ZScaleInterval().get_limits(_mstilt)
            plt.imshow(_mstilt, origin='lower', interpolation='nearest', aspect='auto',
                       vmin=vmin, vmax=vmax)
            for slit in self.slit_idx:
                spat = self.all_trace_dict[slit]['tilts_spat']
                spec = self.all_trace_dict[slit]['tilts']
                spec_fit = self.all_trace_dict[slit]['tilts_fit']
                in_fit = self.all_trace_dict[slit]['tot_mask']
                not_fit = np.invert(in_fit) & (spec > 0)
                fit_rej = in_fit & np.invert(self.all_trace_dict[slit]['fit_mask'])
                fit_keep = in_fit & self.all_trace_dict[slit]['fit_mask']
                plt.scatter(spat[not_fit], spec[not_fit], color='C1', marker='.', s=30, lw=0)
                plt.scatter(spat[fit_rej], spec[fit_rej], color='C3', marker='.', s=30, lw=0)
                plt.scatter(spat[fit_keep], spec[fit_keep], color='k', marker='.', s=30, lw=0)
                with_fit = np.invert(np.all(np.invert(fit_keep), axis=0))
                for t in range(in_fit.shape[1]):
                    if not with_fit[t]:
                        continue
                    l, r = np.nonzero(in_fit[:,t])[0][[0,-1]]
                    plt.plot(spat[l:r+1,t], spec_fit[l:r+1,t], color='k')
            plt.show()

        # Record the Mask
        bpmtilts = np.zeros_like(self.slits.mask, dtype=self.slits.bitmask.minimum_dtype())
        for flag in ['BADTILTCALIB']:
            bpm = self.slits.bitmask.flagged(self.slits.mask, flag)
            if np.any(bpm):
                bpmtilts[bpm] = self.slits.bitmask.turn_on(bpmtilts[bpm], flag)

        # Build and return DataContainer
        tilts_dict = {'coeffs':self.coeffs,
                      'func2d':self.par['func2d'], 'nslit':self.slits.nslits,
                      'spat_order':self.spat_order, 'spec_order':self.spec_order,
                      'spat_id':self.slits.spat_id, 'bpmtilts': bpmtilts,
                      'spat_flexure': self.spat_flexure,
                      'PYP_SPEC': self.spectrograph.spectrograph}
        return WaveTilts(**tilts_dict)
Example No. 16
def test_subtract_companions_from_psf_stamps():
    """Function to test the function which removes companion stars from the 
    surrounds of a PSF star in a PSF star stamp image."""

    setup = pipeline_setup.pipeline_setup({'red_dir': TEST_DIR})

    log = logs.start_stage_log(cwd, 'test_subtract_companions')

    log.info(setup.summary())

    reduction_metadata = metadata.MetaData()
    reduction_metadata.load_a_layer_from_file(setup.red_dir,
                                              'pyDANDIA_metadata.fits',
                                              'reduction_parameters')

    star_catalog_file = os.path.join(TEST_DATA, 'star_catalog.fits')

    ref_star_catalog = catalog_utils.read_ref_star_catalog_file(
        star_catalog_file)

    log.info('Read in catalog of ' + str(len(ref_star_catalog)) + ' stars')

    image_file = os.path.join(TEST_DATA,
                              'lsc1m005-fl15-20170701-0144-e91_cropped.fits')

    image = fits.getdata(image_file)

    psf_idx = [248]
    psf_x = 257.656
    psf_y = 121.365

    stamp_centres = np.array([[psf_x, psf_y]])

    psf_size = 10.0
    stamp_dims = (20, 20)

    stamps = psf.cut_image_stamps(setup, image, stamp_centres, stamp_dims)

    if len(stamps) == 0:

        log.info(
            'ERROR: No PSF stamp images returned.  PSF stars too close to the edge?'
        )

    else:

        for i, s in enumerate(stamps):

            fig = plt.figure(1)
            norm = visualization.ImageNormalize(s.data,
                                                interval=visualization.ZScaleInterval())

            plt.imshow(s.data, origin='lower', cmap=plt.cm.viridis, norm=norm)
            plt.xlabel('X pixel')
            plt.ylabel('Y pixel')
            plt.axis('equal')
            plt.savefig(
                os.path.join(setup.red_dir,
                             'psf_star_stamp' + str(i) + '.png'))
            plt.close(1)

        psf_model = psf.Moffat2D()
        x_cen = psf_size + (psf_x - int(psf_x))
        y_cen = psf_size + (psf_y - int(psf_y))
        psf_radius = 8.0
        psf_params = [
            103301.241291, x_cen, y_cen, 226.750731765, 13004.8930993,
            103323.763627
        ]
        psf_model.update_psf_parameters(psf_params)

        sky_model = psf.ConstantBackground()
        sky_model.background_parameters.constant = 1345.0

        clean_stamps = psf.subtract_companions_from_psf_stamps(
            setup,
            reduction_metadata,
            log,
            ref_star_catalog,
            psf_idx,
            stamps,
            stamp_centres,
            psf_model,
            sky_model,
            diagnostics=True)

    logs.close_log(log)
Example No. 17
def get_image(data, fmt='JPEG', norm='percentile', lo=None, hi=None,
              zcontrast=0.25, nsamples=1000, krej=2.5, max_iterations=5,
              stretch='linear', a=None, bias=0.5, contrast=1, cmap=None,
              dpi=100, **kwargs):
    u"""
    Return a byte array containing image in the given format

    Image scaling is done using `~astropy.visualization`. It includes
    normalization of the input data (mapping to [0, 1]) and stretching -
    optional non-linear mapping [0, 1] -> [0, 1] for contrast enhancement.
    A colormap can be applied to the normalized data. Conversion to the target
    image format is done by matplotlib or Pillow.

    :param array_like data: input 2D image data
    :param str fmt: output image format
    :param str norm: data normalization mode::
        "manual": lower and higher clipping limits are set explicitly
        "minmax": limits are set to the minimum and maximum data values
        "percentile" (default): limits are set based on the specified fraction
            of pixels
        "zscale": use IRAF ZScale algorithm
    :param int | float lo::
        for ``norm`` == "manual", lower data limit
        for ``norm`` == "percentile", lower percentile clipping value,
            defaulting to 10
        for ``norm`` == "zscale", lower limit on the number of rejected pixels,
            defaulting to 5
    :param int | float hi::
        for ``norm`` == "manual", upper data limit
        for ``norm`` == "percentile", upper percentile clipping value,
            defaulting to 98
        for ``norm`` == "zscale", upper limit on the number of rejected pixels,
            defaulting to data.size/2
    :param float zcontrast: for ``norm`` == "zscale", the scaling factor,
        0 < zcontrast < 1, defaulting to 0.25
    :param int nsamples: for ``norm`` == "zscale", the number of points in
        the input array for determining scale factors, defaulting to 1000
    :param float krej: for ``norm`` == "zscale", the sigma clipping factor,
        defaulting to 2.5
    :param int max_iterations: for ``norm`` == "zscale", the maximum number
        of rejection iterations, defaulting to 5
    :param str stretch: [0, 1] → [0, 1] mapping mode::
        "asinh": hyperbolic arcsine stretch y = asinh(x/a)/asinh(1/a)
        "contrast": linear bias/contrast-based stretch
            y = (x - bias)*contrast + 0.5
        "exp": exponential stretch y = (a^x - 1)/(a - 1)
        "histeq": histogram equalization stretch
        "linear" (default): direct mapping
        "log": logarithmic stretch y = log(ax + 1)/log(a + 1)
        "power": power stretch y = x^a
        "sinh": hyperbolic sine stretch y = sinh(x/a)/sinh(1/a)
        "sqrt": square root stretch y = √x
        "square": power stretch y = x^2
    :param float a: non-linear stretch parameter::
        for ``stretch`` == "asinh", the point of transition from linear to
            logarithmic behavior, 0 < a <= 1, defaulting to 0.1
        for ``stretch`` == "exp", base of the exponent, a != 1, defaulting to
            1000
        for ``stretch`` == "log", base of the logarithm minus 1, a > 0,
            defaulting to 1000
        for ``stretch`` == "power", the power index, defaulting to 3
        for ``stretch`` == "sinh", a > 0, defaulting to 1/3
    :param float bias: for ``stretch`` == "contrast", the bias parameter,
        defaulting to 0.5
    :param float contrast: for ``stretch`` == "contrast", the contrast
        parameter, defaulting to 1
    :param str cmap: optional matplotlib colormap name, defaulting
        to grayscale; when a non-grayscale colormap is specified,
        the conversion is always done by matplotlib, regardless of the
        availability of Pillow; see https://matplotlib.org/users/colormaps.html
        for more info on matplotlib colormaps and
            [name for name in matplotlib.cm.cmap_d.keys()
             if not name.endswith('_r')]
        to list the available colormap names
    :param int dpi: target image resolution in dots per inch
    :param kwargs: optional format-specific keyword arguments passed to Pillow,
        e.g. "quality" for JPEG; see
        `https://pillow.readthedocs.io/en/stable/handbook/
        image-file-formats.html`_

    :return: a bytes object containing the image in the given format
    :rtype: bytes
    """
    data = asanyarray(data)

    # Normalize image data
    if norm == 'manual':
        if lo is None:
            raise ValueError(
                'Missing lower clipping boundary for norm="manual"')
        if hi is None:
            raise ValueError(
                'Missing upper clipping boundary for norm="manual"')
    elif norm == 'minmax':
        lo, hi = data.min(), data.max()
    elif norm == 'percentile':
        if lo is None:
            lo = 10
        elif not 0 <= lo <= 100:
            raise ValueError(
                'Lower clipping percentile must be in the [0,100] range')
        if hi is None:
            hi = 98
        elif not 0 <= hi <= 100:
            raise ValueError(
                'Upper clipping percentile must be in the [0,100] range')
        if hi < lo:
            raise ValueError(
                'Upper clipping percentile must be greater or equal to '
                'lower percentile')
        lo, hi = percentile(data, [lo, hi])
    elif norm == 'zscale':
        if lo is None:
            lo = 5
        if hi is None:
            hi = 0.5
        else:
            hi /= data.size
        lo, hi = apy_vis.ZScaleInterval(
            nsamples, zcontrast, hi, lo, krej, max_iterations).get_limits(data)
    else:
        raise ValueError('Unknown normalization mode "{}"'.format(norm))
    data = clip((data - lo)/(hi - lo), 0, 1)

    # Stretch the data
    if stretch == 'asinh':
        if a is None:
            a = 0.1
        apy_vis.AsinhStretch(a)(data, out=data)
    elif stretch == 'contrast':
        if bias != 0.5 or contrast != 1:
            apy_vis.ContrastBiasStretch(contrast, bias)(data, out=data)
    elif stretch == 'exp':
        if a is None:
            a = 1000
        apy_vis.PowerDistStretch(a)(data, out=data)
    elif stretch == 'histeq':
        apy_vis.HistEqStretch(data)(data, out=data)
    elif stretch == 'linear':
        pass
    elif stretch == 'log':
        if a is None:
            a = 1000
        apy_vis.LogStretch(a)(data, out=data)
    elif stretch == 'power':
        if a is None:
            a = 3
        apy_vis.PowerStretch(a)(data, out=data)
    elif stretch == 'sinh':
        if a is None:
            a = 1/3
        apy_vis.SinhStretch(a)(data, out=data)
    elif stretch == 'sqrt':
        apy_vis.SqrtStretch()(data, out=data)
    elif stretch == 'square':
        apy_vis.SquaredStretch()(data, out=data)
    else:
        raise ValueError('Unknown stretch mode "{}"'.format(stretch))

    buf = BytesIO()
    try:
        # Choose the backend for making an image
        if cmap is None:
            cmap = 'gray'
        if cmap == 'gray':
            try:
                # noinspection PyPackageRequirements,PyPep8Naming
                from PIL import Image as pil_image
            except ImportError:
                pil_image = None
        else:
            pil_image = None

        if pil_image is not None:
            # Use Pillow for grayscale output if available; flip the image to
            # match the bottom-to-top FITS convention and convert from [0,1] to
            # unsigned byte
            pil_image.fromarray(
                (data[::-1]*255 + 0.5).astype(uint8),
            ).save(buf, fmt, dpi=(dpi, dpi), **kwargs)
        else:
            # Use matplotlib for non-grayscale colormaps or if PIL is not
            # available
            # noinspection PyPackageRequirements
            from matplotlib import image as mpl_image
            if fmt.lower() == 'png':
                # PNG images are saved upside down by matplotlib, regardless of
                # the origin parameter
                data = data[::-1]
            # noinspection PyTypeChecker
            mpl_image.imsave(
                buf, data, cmap=cmap, format=fmt, origin='lower', dpi=dpi)

        return buf.getvalue()
    finally:
        buf.close()
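
A minimal call sketch for get_image (the file name is hypothetical, and the module-level imports this snippet relies on, such as the numpy helpers, apy_vis and BytesIO, are assumed to be in place):

from astropy.io import fits

data = fits.getdata('example.fits')  # hypothetical input file
png_bytes = get_image(data, fmt='PNG', norm='zscale', stretch='asinh', cmap='viridis')
with open('example.png', 'wb') as f:
    f.write(png_bytes)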
Example No. 18
plt.subplot(1, 4, 1)
plt.title('H-beta Flux')
norm = v.ImageNormalize(Hbeta_flux_map,
                        interval=v.ManualInterval(vmin=0, vmax=500),
                        stretch=v.LogStretch(10))
im = plt.imshow(np.ma.filled(Hbeta_flux_map, fill_value=-10),
                origin='lower',
                norm=norm,
                cmap='Greys')
plt.colorbar(im)

plt.subplot(1, 4, 2)
plt.title('Residuals')
norm = v.ImageNormalize(residuals_map,
                        interval=v.ZScaleInterval(),
                        stretch=v.LinearStretch())
im = plt.imshow(np.ma.filled(residuals_map, fill_value=0),
                origin='lower',
                norm=norm,
                cmap='Greys')
plt.gca().set_xticklabels([])
plt.gca().set_yticklabels([])
plt.colorbar(im)

plt.subplot(1, 4, 3)
plt.title('Velocity')
norm = v.ImageNormalize(velocity_map,
                        interval=v.ManualInterval(vmin=-60, vmax=60),
                        stretch=v.LinearStretch())
im = plt.imshow(np.ma.filled(velocity_map, fill_value=0),
Example No. 19
def run_canvis(field, ID, run, verbose=False, debugmode=False):
    print('\n#########################')
    print('#  CANVIS HAS STARTED   #')
    print('#########################\n')

    translist_path = '/home/fstars/MARY4OZ/transients/transients_coo.txt'
    print('CANVIS will make a gif for field %s Mary ID number %s.\n' %
          (field, ID))
    print('CANVIS will do this by reading %s\n' % translist_path)
    '''Look for the CCD number, RA and DEC of the Mary ID entry that matches inputs.'''
    with open(translist_path) as f:
        for line in f:
            line = line.split()
            #print(line[0])
            if int(line[0]) == ID:
                ra = str(line[1])
                dec = str(line[2])
                #field = str(line[3])
                ccd_num = str(line[6])

    if debugmode:
        print(ra, dec, ccd_num)
    '''Given the CCD number, RA, DEC, go through all the data on this field with this CCD 
    and extract postage stamps around the given RA and DEC.'''
    print(
        "CANVIS will extract postage stamps around RA %s DEC %s for field %s on CCD %s for all seed IDs and dates."
        % (ra, dec, field, ccd_num))
    path = '/fred/oz100/pipes/DWF_PIPE/MARY_WORK/' + field + '_*_*_*/ccd' + ccd_num + '/images_resampled/sci_' + ccd_num + '.resamp.fits'
    fitsfileslist = glob.glob(path)
    mydic = SortedDict()
    vmins = []
    vmaxs = []
    for i in fitsfileslist:
        with fits.open(i) as hdu:
            size = 200
            w = WCS(hdu[0].header)
            head = hdu[0].header
            date = dt.datetime.strptime(head['DATE'], '%Y-%m-%dT%H:%M:%S')
            xlim = head['NAXIS1']
            ylim = head['NAXIS2']
            pixcrd_im = np.array([[xlim, ylim]], np.float_)
            world_im = w.wcs_pix2world(pixcrd_im, 1)
            pixx_im, pixy_im = world_im[0][0], world_im[0][1]
            corners = w.calc_footprint()

            corner_1 = corners[0]
            corner_2 = corners[1]
            corner_3 = corners[2]
            corner_4 = corners[3]
            differnce = corner_1 - corner_2

            pixcrd = np.array([[ra, dec]], np.float_)
            worldpix = w.wcs_world2pix(pixcrd, 1)
            pixx, pixy = worldpix[0][0], worldpix[0][1]

            if float(corner_4[0]) <= float(ra) <= float(corner_1[0]) and float(
                    corner_2[1]) >= float(dec) >= float(corner_1[1]):
                path = i
                mydic[date] = [path, pixx, pixy]
                if debugmode:
                    print(mydic)

    for i, (key, (path, pixx, pixy)) in enumerate(mydic.items()):
        path_cand = '/fred/oz100/CANVIS/cand_images/' + run + '/cand_' + format(
            ID, '05') + '_' + field + '_' + run + '/'
        path_cutout = '/fred/oz100/CANVIS/cand_images/' + run + '/cand_' + format(
            ID, '05') + '_' + field + '_' + run + '/cand_' + format(
                ID, '05') + '_' + run + '_cutout_' + format(i, '03')
        if not os.path.exists(path_cand):
            os.makedirs(path_cand, 0o755)
        else:
            pass
        size = 200
        with fits.open(path) as hdu:
            nom_data = (hdu[0].data - np.min(hdu[0].data)) / (
                np.max(hdu[0].data) - np.min(hdu[0].data))
            cutout = Cutout2D(hdu[0].data, (pixx, pixy), size, wcs=w)
            hdu[0].data = cutout.data
            hdu[0].header['CRPIX1'] = cutout.wcs.wcs.crpix[0]
            hdu[0].header['CRPIX2'] = cutout.wcs.wcs.crpix[1]

            interval = astrovis.ZScaleInterval()
            vmin, vmax = interval.get_limits(hdu[0].data)
            vmins.append(vmin)
            vmaxs.append(vmax)

            hdu.writeto(path_cutout + '_CUTOUT.fits', overwrite=True)
            #plt.axis('off')
            #plt.imshow(hdu[0].data, cmap='gray', vmin=vmin, vmax=vmax)
            #plt.colorbar()
            #plt.savefig(path_cutout+'.png', overwite=True)
            #plt.close()

    files = []
    vmins = []
    vmaxs = []
    path_cand = '/fred/oz100/CANVIS/cand_images/' + run + '/cand_' + format(
        ID, '05') + '_' + field + '_' + run + '/'
    average_vmin = np.average(vmins)
    average_vmax = np.average(vmaxs)
    # path_cutout = '/fred/oz100/CANVIS/cand_images/'+ run +'/cand_'+format(ID, '05')+'_'+ field +'_'+ run +'/cand_'+format(ID, '05')+'_'+run+'_cutout_'+format(i, '03')
    length_num = 1

    for cutouts in os.listdir(path_cand):
        length = []
        # length_num = 1

        if cutouts.endswith('.fits'):
            a = 1
            length_num += 1

    print(length_num)
Example No. 20
    def run(self, maskslits=None, doqa=True, debug=False, show=False):
        """
        Main driver for tracing arc lines

        Code flow:
            1. Extract an arc spectrum down the center of each slit/order
            2. Loop on slits/orders
                i. Trace and fit the arc lines (This is done twice, once
                   with trace_crude as the tracing crutch, then again
                   with a PCA model fit as the crutch).
                ii. Repeat trace.
                iii.  2D Fit to the offset from slitcen
                iv. Save

        Args:
            maskslits (`numpy.ndarray`_, optional):
                Boolean array to ignore slits.
            doqa (bool):
            debug (bool):
            show (bool):

        Returns:
            dict, ndarray:  Tilts dict and maskslits array

        """

        if maskslits is None:
            maskslits = np.zeros(self.nslits, dtype=bool)

        # Extract the arc spectra for all slits
        self.arccen, self.arccen_bpm, self.arc_maskslit = self.extract_arcs()

        # TODO: Leave for now.  Used for debugging
        #        self.par['rm_continuum'] = True
        #        debug = True
        #        show = True

        # Subtract arc continuum
        _msarc = self.msarc.image.copy()
        if self.par['rm_continuum']:
            continuum = self.model_arc_continuum(debug=debug)
            _msarc -= continuum
            if debug:
                # TODO: Put this into a function
                vmin, vmax = visualization.ZScaleInterval().get_limits(_msarc)
                w, h = plt.figaspect(1)
                fig = plt.figure(figsize=(3 * w, h))
                ax = fig.add_axes([0.15 / 3, 0.1, 0.8 / 3, 0.8])
                ax.imshow(self.msarc.image,
                          origin='lower',
                          interpolation='nearest',
                          aspect='auto',
                          vmin=vmin,
                          vmax=vmax)
                ax.set_title('MasterArc')
                ax = fig.add_axes([1.15 / 3, 0.1, 0.8 / 3, 0.8])
                ax.imshow(continuum,
                          origin='lower',
                          interpolation='nearest',
                          aspect='auto',
                          vmin=vmin,
                          vmax=vmax)
                ax.set_title('Continuum')
                ax = fig.add_axes([2.15 / 3, 0.1, 0.8 / 3, 0.8])
                ax.imshow(_msarc,
                          origin='lower',
                          interpolation='nearest',
                          aspect='auto',
                          vmin=vmin,
                          vmax=vmax)
                ax.set_title('MasterArc - Continuum')
                plt.show()

        # maskslit
        self.mask = np.any([maskslits, self.arc_maskslit == 1], axis=0)
        gdslits = np.where(np.invert(self.mask))[0]

        # Final tilts image
        self.final_tilts = np.zeros(self.shape_science, dtype=float)
        max_spat_dim = (np.asarray(self.par['spat_order']) + 1).max()
        max_spec_dim = (np.asarray(self.par['spec_order']) + 1).max()
        self.coeffs = np.zeros((max_spec_dim, max_spat_dim, self.nslits))
        self.spat_order = np.zeros(self.nslits, dtype=int)
        self.spec_order = np.zeros(self.nslits, dtype=int)

        # TODO sort out show methods for debugging
        #if show:
        #    viewer,ch = ginga.show_image(self.msarc*(self.slitmask > -1),chname='tilts')

        # Loop on all slits
        for slit in gdslits:
            msgs.info('Computing tilts for slit {0}/{1}'.format(
                slit, self.nslits - 1))
            # Identify lines for tracing tilts
            msgs.info('Finding lines for tilt analysis')
            self.lines_spec, self.lines_spat \
                    = self.find_lines(self.arccen[:,slit], self.slitcen[:,slit], slit,
                                      bpm=self.arccen_bpm[:,slit], debug=False) #debug)
            if self.lines_spec is None:
                self.mask[slit] = True
                maskslits[slit] = True
                continue

            thismask = self.slitmask == slit

            # Performs the initial tracing of the line centroids as a
            # function of spatial position resulting in 1D traces for
            # each line.
            msgs.info('Trace the tilts')
            self.trace_dict = self.trace_tilts(_msarc, self.lines_spec,
                                               self.lines_spat, thismask,
                                               self.slitcen[:, slit])

            # TODO: Show the traces before running the 2D fit

            #if show:
            #    ginga.show_tilts(viewer, ch, self.trace_dict)

            self.spat_order[slit] = self._parse_param(self.par, 'spat_order',
                                                      slit)
            self.spec_order[slit] = self._parse_param(self.par, 'spec_order',
                                                      slit)
            # 2D model of the tilts, includes construction of QA
            # NOTE: This also fills in self.all_fit_dict and self.all_trace_dict
            coeff_out = self.fit_tilts(self.trace_dict,
                                       thismask,
                                       self.slitcen[:, slit],
                                       self.spat_order[slit],
                                       self.spec_order[slit],
                                       slit,
                                       doqa=doqa,
                                       show_QA=show,
                                       debug=show)
            self.coeffs[:self.spec_order[slit] + 1, :self.spat_order[slit] + 1,
                        slit] = coeff_out

            # Tilts are created with the size of the original slitmask,
            # which corresonds to the same binning as the science
            # images, trace images, and pixelflats etc.
            self.tilts = tracewave.fit2tilts(self.slitmask_science.shape,
                                             coeff_out, self.par['func2d'])
            # Save to final image
            thismask_science = self.slitmask_science == slit
            self.final_tilts[thismask_science] = self.tilts[thismask_science]

        if debug:
            # TODO: Add this to the show method?
            vmin, vmax = visualization.ZScaleInterval().get_limits(_msarc)
            plt.imshow(_msarc,
                       origin='lower',
                       interpolation='nearest',
                       aspect='auto',
                       vmin=vmin,
                       vmax=vmax)
            for slit in gdslits:
                spat = self.all_trace_dict[slit]['tilts_spat']
                spec = self.all_trace_dict[slit]['tilts']
                spec_fit = self.all_trace_dict[slit]['tilts_fit']
                in_fit = self.all_trace_dict[slit]['tot_mask']
                not_fit = np.invert(in_fit) & (spec > 0)
                fit_rej = in_fit & np.invert(
                    self.all_trace_dict[slit]['fit_mask'])
                fit_keep = in_fit & self.all_trace_dict[slit]['fit_mask']
                plt.scatter(spat[not_fit],
                            spec[not_fit],
                            color='C1',
                            marker='.',
                            s=30,
                            lw=0)
                plt.scatter(spat[fit_rej],
                            spec[fit_rej],
                            color='C3',
                            marker='.',
                            s=30,
                            lw=0)
                plt.scatter(spat[fit_keep],
                            spec[fit_keep],
                            color='k',
                            marker='.',
                            s=30,
                            lw=0)
                with_fit = np.invert(np.all(np.invert(fit_keep), axis=0))
                for t in range(in_fit.shape[1]):
                    if not with_fit[t]:
                        continue
                    l, r = np.nonzero(in_fit[:, t])[0][[0, -1]]
                    plt.plot(spat[l:r + 1, t], spec_fit[l:r + 1, t], color='k')
            plt.show()

        self.tilts_dict = {
            'tilts': self.final_tilts,
            'coeffs': self.coeffs,
            'slitcen': self.slitcen,
            'func2d': self.par['func2d'],
            'nslit': self.nslits,
            'spat_order': self.spat_order,
            'spec_order': self.spec_order
        }
        return self.tilts_dict, maskslits
Example No. 21
import numpy as np
from astropy.io import fits
import astropy.visualization as visualization
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from kernels import convolve
import sys
plt.style.use('mystyle-2.mplstyle')
zscale = visualization.ZScaleInterval()
def show(matrix):
    plt.figure()
    imshow(matrix)
def reject_galaxies_radius(galaxylist, internalradius=8, ignoreborder=150):

    print('- Loading Mask')
    # import mask with objects to neglect and define the area of rejection
    maskpath = '../Images/nostar_mask.fits'
    hdulist = fits.open(maskpath)
    mask = hdulist[0].data
    mask = np.where(mask == 1, 0, 1)
    hdulist.close()

    print('- Loading Image')
    # load the working image (with frame but no star)
    maskpath = 'A1_mosaic_nostar.fits'
    hdulist = fits.open(maskpath)
    image = hdulist[0].data
    hdulist.close()