Example #1
def raft_level_oscan_correlations(bias_files, buffer=10, title='',
                                  vrange=None, stretch=viz.LinearStretch):
    """
    Compute the correlation coefficients between the overscan pixels
    of the 144 amplifiers in a raft.

    Parameters
    ----------
    bias_files: dict
        Dictionary of bias image files, indexed by sensor slot id.
    buffer: int [10]
        Buffer region around perimeter of serial overscan region to
        avoid when computing the correlation coefficients.
    title: str ['']
        Plot title.
    vrange: (float, float) [None]
        Minimum and maximum values for color scale range. If None, then
        the range of the central 98th percentile of the absolute value
        of the data is used.
    stretch: astropy.visualization.BaseStretch [LinearStretch]
        Stretch to use for the color scale.

    Returns
    -------
    (matplotlib.figure.Figure, np.array): The figure containing the plot and
        the numpy array containing the correlation coefficients.
    """
    slots = 'S00 S01 S02 S10 S11 S12 S20 S21 S22'.split()
    bbox = None
    overscans = []
    for slot in slots:
        ccd = sensorTest.MaskedCCD(bias_files[slot])
        if bbox is None:
            bbox = ccd.amp_geom.serial_overscan
            bbox.grow(-buffer)
        for amp in ccd:
            image = ccd[amp].getImage()
            overscans.append(image.Factory(image, bbox).getArray())
    namps = len(overscans)
    data = np.array([np.corrcoef(overscans[i[0]].ravel(),
                                 overscans[i[1]].ravel())[0, 1]
                     for i in itertools.product(range(namps), range(namps))])
    data = data.reshape((namps, namps))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_title(title, fontsize='medium')

    interval = viz.PercentileInterval(98.)
    if vrange is None:
        vrange = interval.get_limits(np.abs(data.flatten()))
    norm = ImageNormalize(vmin=vrange[0], vmax=vrange[1], stretch=stretch())
    image = ax.imshow(data, interpolation='none', norm=norm)
    plt.colorbar(image)

    set_ticks(ax, slots, amps=16)

    return fig, data
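A minimal, self-contained sketch of the same display pattern, assuming only numpy, matplotlib and astropy.visualization (the LSST raft/amplifier handling above is omitted and the segments are synthetic): build a correlation matrix and show it through an ImageNormalize built from a PercentileInterval of the absolute values.

import numpy as np
import matplotlib.pyplot as plt
from astropy import visualization as viz
from astropy.visualization.mpl_normalize import ImageNormalize

# Fake "overscan" segments standing in for the amplifier regions.
rng = np.random.default_rng(0)
overscans = [rng.normal(size=500) for _ in range(16)]

namps = len(overscans)
data = np.array([[np.corrcoef(overscans[i], overscans[j])[0, 1]
                  for j in range(namps)] for i in range(namps)])

vrange = viz.PercentileInterval(98.).get_limits(np.abs(data.ravel()))
norm = ImageNormalize(vmin=vrange[0], vmax=vrange[1], stretch=viz.LinearStretch())

fig, ax = plt.subplots()
im = ax.imshow(data, interpolation='none', norm=norm)
fig.colorbar(im)
plt.show()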
Example #2
    def plot_norm(self,
                  stretch='linear',
                  power=1.0,
                  asinh_a=0.1,
                  min_cut=None,
                  max_cut=None,
                  min_percent=None,
                  max_percent=None,
                  percent=None,
                  clip=True):
        """Create a matplotlib norm object for plotting.

        This is a copy of the function that will be available in Astropy 1.3 as
        `astropy.visualization.mpl_normalize.simple_norm`.

        See the parameter descriptions there.

        Examples
        --------
        >>> image = SkyImage()
        >>> norm = image.plot_norm(stretch='sqrt', max_percent=99)
        >>> image.plot(norm=norm)
        """
        import astropy.visualization as v
        from astropy.visualization.mpl_normalize import ImageNormalize

        if percent is not None:
            interval = v.PercentileInterval(percent)
        elif min_percent is not None or max_percent is not None:
            interval = v.AsymmetricPercentileInterval(min_percent or 0.,
                                                      max_percent or 100.)
        elif min_cut is not None or max_cut is not None:
            interval = v.ManualInterval(min_cut, max_cut)
        else:
            interval = v.MinMaxInterval()

        if stretch == 'linear':
            stretch = v.LinearStretch()
        elif stretch == 'sqrt':
            stretch = v.SqrtStretch()
        elif stretch == 'power':
            stretch = v.PowerStretch(power)
        elif stretch == 'log':
            stretch = v.LogStretch()
        elif stretch == 'asinh':
            stretch = v.AsinhStretch(asinh_a)
        else:
            raise ValueError('Unknown stretch: {0}.'.format(stretch))

        vmin, vmax = interval.get_limits(self.data)

        return ImageNormalize(vmin=vmin, vmax=vmax, stretch=stretch, clip=clip)
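For comparison, a short usage sketch of the astropy function this method mirrors (astropy.visualization.simple_norm, available from Astropy 1.3 onwards); the data here are synthetic.

import numpy as np
import matplotlib.pyplot as plt
from astropy.visualization import simple_norm

data = np.random.default_rng(1).normal(loc=100., scale=10., size=(64, 64))

# Roughly equivalent to plot_norm(stretch='sqrt', max_percent=99) above.
norm = simple_norm(data, stretch='sqrt', max_percent=99)
plt.imshow(data, norm=norm, origin='lower')
plt.colorbar()
plt.show()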
Example #3
def select_cutout(image, wcs):

    # Estimate how many pointings are in the mosaic, assuming ~300 px per
    # pointing (this may not always be accurate).
    nrow = image.shape[0] // 300.
    ncol = image.shape[1] // 300.
    # Measure the exact width of a row and of a column.
    drow = image.shape[0] / nrow
    dcol = image.shape[1] / ncol
    # Show the image so the user can select the desired section by clicking
    # on its center.

    fig, ax = plt.subplots(1, 1)
    interval = vis.PercentileInterval(99.9)
    vmin, vmax = interval.get_limits(image)
    norm = vis.ImageNormalize(vmin=vmin,
                              vmax=vmax,
                              stretch=vis.LogStretch(1000))
    ax.imshow(image, cmap=plt.cm.Greys, norm=norm, origin='lower')
    for x in np.arange(0, image.shape[1], dcol):
        ax.axvline(x)
    for y in np.arange(0, image.shape[0], drow):
        ax.axhline(y)

    def onclick(event):

        ix, iy = event.xdata, event.ydata
        col = ix // 300.
        row = iy // 300.
        print(col, row)
        global x_cen, y_cen
        x_cen = 150 + 300 * col  # x of the center of the quadrant
        y_cen = 150 + 300 * row  # y of the center of the quadrant
        print('x: {:3.0f}, y: {:3.0f}'.format(x_cen, y_cen))
        if event.key == 'q':
            fig.canvas.mpl_disconnect(cid)

    cid = fig.canvas.mpl_connect('button_press_event', onclick)
    plt.show()

    nrow = image.shape[0] // 300.
    ncol = image.shape[1] // 300.
    print(image.shape[0] / nrow)
    x = int(x_cen)
    y = int(y_cen)
    print(x, y)
    cutout = Cutout2D(image, (x, y),
                      size=(image.shape[0] / nrow - 20) * u.pixel,
                      wcs=wcs)
    return cutout
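A minimal sketch of the Cutout2D call used above, with a synthetic image and a placeholder WCS instead of the interactively chosen center:

import numpy as np
import astropy.units as u
from astropy.wcs import WCS
from astropy.nddata import Cutout2D

image = np.random.random((600, 900))
wcs = WCS(naxis=2)  # placeholder WCS, for illustration only

# (x, y) center of one 300x300-pixel quadrant, cut down by a 20-pixel margin.
cutout = Cutout2D(image, position=(450, 150), size=280 * u.pixel, wcs=wcs)
print(cutout.data.shape)  # (280, 280)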
Example #4
def make_analysis_forms(
        basepath="/orange/adamginsburg/web/secure/ALMA-IMF/October31Release/",
        base_form_url="https://docs.google.com/forms/d/e/1FAIpQLSczsBdB3Am4znOio2Ky5GZqAnRYDrYTD704gspNu7fAMm2-NQ/viewform?embedded=true",
        dontskip_noresid=False):
    import glob
    from diagnostic_images import load_images, show as show_images
    from astropy import visualization
    import pylab as pl

    savepath = f'{basepath}/quicklooks'

    try:
        os.mkdir(savepath)
    except FileExistsError:
        pass

    filedict = {
        (field, band, config, robust, selfcal): glob.glob(
            f"{field}/B{band}/{imtype}{field}*_B{band}_*_{config}_robust{robust}*selfcal{selfcal}*.image.tt0*.fits"
        )
        for field in
        "G008.67 G337.92 W43-MM3 G328.25 G351.77 G012.80 G327.29 W43-MM1 G010.62 W51-IRS2 W43-MM2 G333.60 G338.93 W51-E G353.41"
        .split() for band in (3, 6)
        #for config in ('7M12M', '12M')
        for config in ('12M', )
        #for robust in (-2, 0, 2)
        for robust in (0, ) for selfcal in ("", ) + tuple(range(0, 9))
        for imtype in (('', ) if 'October31' in basepath else ('cleanest/',
                                                               'bsens/'))
    }
    badfiledict = {key: val for key, val in filedict.items() if len(val) == 1}
    print(f"Bad files: {badfiledict}")
    filedict = {key: val for key, val in filedict.items() if len(val) > 1}
    filelist = [key + (fn, ) for key, val in filedict.items() for fn in val]

    prev = 'index.html'

    flist = []

    #for field in "G008.67 G337.92 W43-MM3 G328.25 G351.77 G012.80 G327.29 W43-MM1 G010.62 W51-IRS2 W43-MM2 G333.60 G338.93 W51-E G353.41".split():
    ##for field in ("G333.60",):
    #    for band in (3,6):
    #        for config in ('7M12M', '12M'):
    #            for robust in (-2, 0, 2):

    #                # for not all-in-the-same-place stuff
    #                fns = [x for x in glob.glob(f"{field}/B{band}/{field}*_B{band}_*_{config}_robust{robust}*selfcal[0-9]*.image.tt0*.fits") ]

    #                for fn in fns:
    for ii, (field, band, config, robust, selfcal, fn) in enumerate(filelist):

        image = fn
        basename, suffix = image.split(".image.tt0")
        if 'diff' in suffix or 'bsens-cleanest' in suffix:
            continue
        outname = basename.split("/")[-1]

        if prev == outname + ".html":
            print(
                f"{ii}: {(field, band, config, robust, fn)} yielded the same prev "
                f"{prev} as last time, skipping.")
            continue

        jj = 1
        while jj < len(filelist):
            if ii + jj < len(filelist):
                next_ = filelist[ii + jj][5].split(".image.tt0")[0].split(
                    "/")[-1] + ".html"
            else:
                next_ = "index.html"

            if next_ == outname + ".html":
                jj = jj + 1
            else:
                break

        assert next_ != outname + ".html"

        try:
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore')
                print(f"{ii}: {(field, band, config, robust, fn, selfcal)}"
                      f" basename='{basename}', suffix='{suffix}'")
                imgs, cubes = load_images(basename, suffix=suffix)
        except KeyError as ex:
            print(ex)
            raise
        except Exception as ex:
            print(f"EXCEPTION: {type(ex)}: {str(ex)}")
            raise  # re-raise; replace this with `continue` to skip failed files instead
        norm = visualization.ImageNormalize(
            stretch=visualization.AsinhStretch(),
            interval=visualization.PercentileInterval(99.95))
        # set the scaling based on one of these...
        # (this call in-place modifies norm, according to the docs)
        if 'residual' in imgs:
            norm(imgs['residual'][imgs['residual'] == imgs['residual']])
            imnames_toplot = ('mask', 'model', 'image', 'residual')
        elif 'image' in imgs and dontskip_noresid:
            imnames_toplot = (
                'image',
                'mask',
            )
            norm(imgs['image'][imgs['image'] == imgs['image']])
        else:
            print(
                f"Skipped {fn} because no image OR residual was found.  imgs.keys={imgs.keys()}"
            )
            continue
        pl.close(1)
        pl.figure(1, figsize=(14, 6))
        show_images(imgs, norm=norm, imnames_toplot=imnames_toplot)

        pl.savefig(f"{savepath}/{outname}.png", dpi=150, bbox_inches='tight')

        metadata = {
            'field': field,
            'band': band,
            'selfcal': selfcal,  #get_selfcal_number(basename),
            'array': config,
            'robust': robust,
            'finaliter': 'finaliter' in fn,
        }
        make_quicklook_analysis_form(filename=outname,
                                     metadata=metadata,
                                     savepath=savepath,
                                     prev=prev,
                                     next_=next_,
                                     base_form_url=base_form_url)
        metadata['outname'] = outname
        metadata['suffix'] = suffix
        if robust == 0:
            # only keep robust=0 for simplicity
            flist.append(metadata)
        prev = outname + ".html"

    #make_rand_html(savepath)
    make_index(savepath, flist)

    return flist
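A minimal sketch of the normalization trick used above: an ImageNormalize built from an AsinhStretch and a PercentileInterval, whose limits are fixed in place by calling it on the finite pixels of one reference image so the same norm can be reused for every panel (synthetic data).

import numpy as np
from astropy import visualization

img = np.random.default_rng(2).exponential(size=(128, 128))
img[0, 0] = np.nan

norm = visualization.ImageNormalize(
    stretch=visualization.AsinhStretch(),
    interval=visualization.PercentileInterval(99.95))

# `img == img` is False for NaNs, so only finite pixels are passed;
# the call fixes norm.vmin/norm.vmax in place.
norm(img[img == img])
print(norm.vmin, norm.vmax)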
Example #5
def make_sed_plot(coordinate,
                  mgpsfile,
                  width=1 * u.arcmin,
                  surveys=Magpis.list_surveys(),
                  figure=None,
                  regname='GAL_031'):

    mgps_fh = fits.open(mgpsfile)[0]
    frame = wcs.utils.wcs_to_celestial_frame(wcs.WCS(mgps_fh.header))

    coordname = "{0:06.3f}_{1:06.3f}".format(coordinate.galactic.l.deg,
                                             coordinate.galactic.b.deg)

    mgps_cutout = Cutout2D(mgps_fh.data,
                           coordinate.transform_to(frame.name),
                           size=width * 2,
                           wcs=wcs.WCS(mgps_fh.header))
    print(
        f"Retrieving MAGPIS data for {coordname} ({coordinate.to_string()} {coordinate.frame.name})"
    )
    # we're treating 'width' as a radius elsewhere, here it's a full width
    images = {
        survey: getimg(coordinate, image_size=width * 2.75, survey=survey)
        for survey in surveys
    }
    images = {x: y for x, y in images.items() if y is not None}
    images['mgps'] = [mgps_cutout]

    regdir = os.path.join(paths.basepath, regname)
    if not os.path.exists(regdir):
        os.mkdir(regdir)
    higaldir = os.path.join(paths.basepath, regname, 'HiGalCutouts')
    if not os.path.exists(higaldir):
        os.mkdir(higaldir)
    if not any([
            os.path.exists(f"{higaldir}/{coordname}_{wavelength}.fits")
            for wavelength in map(int, HiGal.HIGAL_WAVELENGTHS.values())
    ]):
        print(
            f"Retrieving HiGal data for {coordname} ({coordinate.to_string()} {coordinate.frame.name})"
        )
        higal_ims = HiGal.get_images(coordinate, radius=width * 1.5)
        for hgim in higal_ims:
            images['HiGal{0}'.format(hgim[0].header['WAVELEN'])] = hgim
            hgim.writeto(
                f"{higaldir}/{coordname}_{hgim[0].header['WAVELEN']}.fits")
    else:
        print(
            f"Loading HiGal data from disk for {coordname} ({coordinate.to_string()} {coordinate.frame.name})"
        )
        for wavelength in map(int, HiGal.HIGAL_WAVELENGTHS.values()):
            hgfn = f"{higaldir}/{coordname}_{wavelength}.fits"
            if os.path.exists(hgfn):
                hgim = fits.open(hgfn)
                images['HiGal{0}'.format(hgim[0].header['WAVELEN'])] = hgim

    if 'gpsmsx2' in images:
        # redundant, save some space for a SED plot
        del images['gpsmsx2']
    if 'gps90' in images:
        # too low-res to be useful
        del images['gps90']

    if figure is None:
        figure = pl.figure(figsize=(15, 12))

    # coordinate stuff so images can be reprojected to same frame
    ww = mgps_cutout.wcs.celestial
    target_header = ww.to_header()
    del target_header['LONPOLE']
    del target_header['LATPOLE']
    mgps_pixscale = (wcs.utils.proj_plane_pixel_area(ww) * u.deg**2)**0.5
    target_header['NAXIS'] = 2
    target_header['NAXIS1'] = target_header['NAXIS2'] = (
        width / mgps_pixscale).decompose().value
    #shape = [int((width / mgps_pixscale).decompose().value)]*2
    outframe = wcs.utils.wcs_to_celestial_frame(ww)
    crd_outframe = coordinate.transform_to(outframe)

    figure.clf()

    imagelist = sorted(images.items(), key=lambda x: wlmap[x[0]])

    #for ii, (survey,img) in enumerate(images.items()):
    for ii, (survey, img) in enumerate(imagelist):

        if hasattr(img[0], 'header'):
            inwcs = wcs.WCS(img[0].header).celestial
            pixscale_in = (wcs.utils.proj_plane_pixel_area(inwcs) *
                           u.deg**2)**0.5

            target_header['CDELT1'] = -pixscale_in.value
            target_header['CDELT2'] = pixscale_in.value
            target_header['CRVAL1'] = crd_outframe.spherical.lon.deg
            target_header['CRVAL2'] = crd_outframe.spherical.lat.deg
            axsize = int((width * 2.5 / pixscale_in).decompose().value)
            target_header['NAXIS1'] = target_header['NAXIS2'] = axsize
            target_header['CRPIX1'] = target_header['NAXIS1'] / 2
            target_header['CRPIX2'] = target_header['NAXIS2'] / 2
            shape_out = [axsize, axsize]

            print(
                f"Reprojecting {survey} to scale {pixscale_in} with shape {shape_out} and center {crd_outframe.to_string()}"
            )

            outwcs = wcs.WCS(target_header)

            new_img, _ = reproject.reproject_interp((img[0].data, inwcs),
                                                    target_header,
                                                    shape_out=shape_out)
        else:
            new_img = img[0].data
            outwcs = img[0].wcs
            pixscale_in = (wcs.utils.proj_plane_pixel_area(outwcs) *
                           u.deg**2)**0.5

        ax = figure.add_subplot(4, 5, ii + 1, projection=outwcs)
        ax.set_title("{0}: {1}".format(survey_titles[survey], wlmap[survey]))

        if not np.any(np.isfinite(new_img)):
            print(f"SKIPPING {survey}")
            continue

        norm = visualization.ImageNormalize(
            new_img,
            interval=visualization.PercentileInterval(99.95),
            stretch=visualization.AsinhStretch(),
        )

        ax.imshow(new_img, origin='lower', interpolation='none', norm=norm)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.xaxis.set_ticklabels('')
        ax.yaxis.set_ticklabels('')
        ax.coords[0].set_ticklabel_visible(False)
        ax.coords[1].set_ticklabel_visible(False)

        if 'GLON' in outwcs.wcs.ctype[0]:
            xpix, ypix = outwcs.wcs_world2pix(coordinate.galactic.l,
                                              coordinate.galactic.b, 0)
        else:
            xpix, ypix = outwcs.wcs_world2pix(coordinate.fk5.ra,
                                              coordinate.fk5.dec, 0)
        ax.set_xlim(xpix - (width / pixscale_in), xpix + (width / pixscale_in))
        ax.set_ylim(ypix - (width / pixscale_in), ypix + (width / pixscale_in))

        # scalebar = 1 arcmin

        ax.plot([
            xpix - width / pixscale_in + 5 * u.arcsec / pixscale_in,
            xpix - width / pixscale_in + 65 * u.arcsec / pixscale_in
        ], [
            ypix - width / pixscale_in + 5 * u.arcsec / pixscale_in,
            ypix - width / pixscale_in + 5 * u.arcsec / pixscale_in
        ],
                linestyle='-',
                linewidth=1,
                color='w')
        ax.plot(crd_outframe.spherical.lon.deg,
                crd_outframe.spherical.lat.deg,
                marker=((0, -10), (0, -4)),
                color='w',
                linestyle='none',
                markersize=20,
                markeredgewidth=0.5,
                transform=ax.get_transform('world'))
        ax.plot(crd_outframe.spherical.lon.deg,
                crd_outframe.spherical.lat.deg,
                marker=((4, 0), (10, 0)),
                color='w',
                linestyle='none',
                markersize=20,
                markeredgewidth=0.5,
                transform=ax.get_transform('world'))

    pl.tight_layout()
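A self-contained sketch of the per-panel reprojection step above (assumes the reproject package is installed; the WCS objects are built in memory rather than read from FITS headers, and make_wcs is a hypothetical helper for this sketch only):

import numpy as np
from astropy.wcs import WCS
import reproject

def make_wcs(crpix, cdelt):
    # Build a simple TAN celestial WCS for illustration.
    w = WCS(naxis=2)
    w.wcs.ctype = ['RA---TAN', 'DEC--TAN']
    w.wcs.crval = [266.83, -28.37]
    w.wcs.crpix = crpix
    w.wcs.cdelt = [-cdelt, cdelt]
    return w

wcs_in = make_wcs(crpix=[50, 50], cdelt=0.001)
wcs_out = make_wcs(crpix=[25, 25], cdelt=0.002)

data = np.random.random((100, 100))
new_img, footprint = reproject.reproject_interp((data, wcs_in), wcs_out,
                                                shape_out=(50, 50))
print(new_img.shape)  # (50, 50)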
Example #6
    def makeSlitIllum(self, adinputs=None, **params):
        """
        Makes the processed Slit Illumination Function by binning a 2D
        spectrum along the dispersion direction, fitting a smooth function
        for each bin, fitting a smooth 2D model, and reconstructing the 2D
        array using this last model.

        Its implementation is based on IRAF's `noao.twodspec.longslit.illumination`
        task, following the algorithm described in [Valdes, 1986].

        It expects the input calibration image to be a dispersed image of the
        slit without illumination problems (e.g., a twilight flat). The spectrum
        is not required to be smooth in wavelength and may contain strong
        emission and absorption lines. The image should contain a `.mask`
        attribute in each extension, and it is expected to be overscan and bias
        corrected.

        Parameters
        ----------
        adinputs : list
            List of AstroData objects containing the dispersed image of the
            slit of a source free of illumination problems. The data needs to
            have been overscan and bias corrected and is expected to have a
            Data Quality mask.
        bins : {None, int}, optional
            Total number of bins across the dispersion axis. If None,
            the number of bins will match the number of extensions on each
            input AstroData object. If it is an int, it will create N bins
            of the same size.
        border : int, optional
            Border size that is added on every edge of the slit illumination
            image before cutting it down to the input AstroData frame.
        smooth_order : int, optional
            Order of the spline used to smooth the data in each bin
            (Default: 3).
        x_order : int, optional
            Order of the x-component in the Chebyshev2D model used to
            reconstruct the 2D data from the binned data.
        y_order : int, optional
            Order of the y-component in the Chebyshev2D model used to
            reconstruct the 2D data from the binned data.

        Returns
        -------
        list of AstroData
            AstroData objects containing the Slit Illumination Response
            Function for each input object.

        References
        ----------
        .. [Valdes, 1986] Francisco Valdes "Reduction Of Long Slit Spectra With
           IRAF", Proc. SPIE 0627, Instrumentation in Astronomy VI,
           (13 October 1986); https://doi.org/10.1117/12.968155
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]

        suffix = params["suffix"]
        bins = params["bins"]
        border = params["border"]
        debug_plot = params["debug_plot"]
        smooth_order = params["smooth_order"]
        cheb2d_x_order = params["x_order"]
        cheb2d_y_order = params["y_order"]

        ad_outputs = []
        for ad in adinputs:

            if len(ad) > 1 and "mosaic" not in ad[0].wcs.available_frames:

                log.info('Add "mosaic" gWCS frame to input data')
                geotable = import_module('.geometry_conf', self.inst_lookups)

                # deepcopy prevents modifying input `ad` inplace
                ad = transform.add_mosaic_wcs(deepcopy(ad), geotable)

                log.info("Temporarily mosaicking multi-extension file")
                mosaicked_ad = transform.resample_from_wcs(
                    ad,
                    "mosaic",
                    attributes=None,
                    order=1,
                    process_objcat=False)

            else:

                log.info('Input data already has a single extension or a '
                         '"mosaic" frame.')

                # deepcopy prevents modifying input `ad` inplace
                mosaicked_ad = deepcopy(ad)

            log.info("Transposing data if needed")
            dispaxis = 2 - mosaicked_ad[0].dispersion_axis()  # python sense
            should_transpose = dispaxis == 1

            data, mask, variance = _transpose_if_needed(
                mosaicked_ad[0].data,
                mosaicked_ad[0].mask,
                mosaicked_ad[0].variance,
                transpose=should_transpose)

            log.info("Masking data")
            data = np.ma.masked_array(data, mask=mask)
            variance = np.ma.masked_array(variance, mask=mask)
            std = np.sqrt(variance)  # Easier to work with

            log.info("Creating bins for data and variance")
            height = data.shape[0]
            width = data.shape[1]

            if bins is None:
                nbins = max(len(ad), 12)
                bin_limits = np.linspace(0, height, nbins + 1, dtype=int)
            elif isinstance(bins, int):
                nbins = bins
                bin_limits = np.linspace(0, height, nbins + 1, dtype=int)
            else:
                # ToDo: Handle input bins as array
                raise TypeError("Expected None or Int for `bins`. "
                                "Found: {}".format(type(bins)))

            bin_top = bin_limits[1:]
            bin_bot = bin_limits[:-1]
            binned_data = np.zeros_like(data)
            binned_std = np.zeros_like(std)

            log.info("Smooth binned data and variance, and normalize them by "
                     "smoothed central value")
            for bin_idx, (b0, b1) in enumerate(zip(bin_bot, bin_top)):

                rows = np.arange(width)

                avg_data = np.ma.mean(data[b0:b1], axis=0)
                model_1d_data = astromodels.UnivariateSplineWithOutlierRemoval(
                    rows, avg_data, order=smooth_order)

                avg_std = np.ma.mean(std[b0:b1], axis=0)
                model_1d_std = astromodels.UnivariateSplineWithOutlierRemoval(
                    rows, avg_std, order=smooth_order)

                slit_central_value = model_1d_data(rows)[width // 2]
                binned_data[b0:b1] = model_1d_data(rows) / slit_central_value
                binned_std[b0:b1] = model_1d_std(rows) / slit_central_value

            log.info("Reconstruct 2D mosaicked data")
            bin_center = np.array(0.5 * (bin_bot + bin_top), dtype=int)
            cols_fit, rows_fit = np.meshgrid(np.arange(width), bin_center)

            fitter = fitting.SLSQPLSQFitter()
            model_2d_init = models.Chebyshev2D(x_degree=cheb2d_x_order,
                                               x_domain=(0, width),
                                               y_degree=cheb2d_y_order,
                                               y_domain=(0, height))

            model_2d_data = fitter(model_2d_init, cols_fit, rows_fit,
                                   binned_data[rows_fit, cols_fit])

            model_2d_std = fitter(model_2d_init, cols_fit, rows_fit,
                                  binned_std[rows_fit, cols_fit])

            rows_val, cols_val = \
                np.mgrid[-border:height+border, -border:width+border]

            slit_response_data = model_2d_data(cols_val, rows_val)
            slit_response_mask = np.pad(
                mask, border, mode='edge')  # ToDo: any update to the mask?
            slit_response_std = model_2d_std(cols_val, rows_val)
            slit_response_var = slit_response_std**2

            del cols_fit, cols_val, rows_fit, rows_val

            _data, _mask, _variance = _transpose_if_needed(
                slit_response_data,
                slit_response_mask,
                slit_response_var,
                transpose=dispaxis == 1)

            log.info("Update slit response data and data_section")
            slit_response_ad = deepcopy(mosaicked_ad)
            slit_response_ad[0].data = _data
            slit_response_ad[0].mask = _mask
            slit_response_ad[0].variance = _variance

            if "mosaic" in ad[0].wcs.available_frames:

                log.info(
                    "Map coordinates between slit function and mosaicked data"
                )  # ToDo: Improve message?
                slit_response_ad = _split_mosaic_into_extensions(
                    ad, slit_response_ad, border_size=border)

            elif len(ad) == 1:

                log.info("Trim out borders")

                slit_response_ad[0].data = \
                    slit_response_ad[0].data[border:-border, border:-border]
                slit_response_ad[0].mask = \
                    slit_response_ad[0].mask[border:-border, border:-border]
                slit_response_ad[0].variance = \
                    slit_response_ad[0].variance[border:-border, border:-border]

            log.info("Update metadata and filename")
            gt.mark_history(slit_response_ad,
                            primname=self.myself(),
                            keyword=timestamp_key)

            slit_response_ad.update_filename(suffix=suffix, strip=True)
            ad_outputs.append(slit_response_ad)

            # Plotting ------
            if debug_plot:

                log.info("Creating plots")
                palette = copy(plt.cm.cividis)
                palette.set_bad('r', 0.75)

                norm = vis.ImageNormalize(data[~data.mask],
                                          stretch=vis.LinearStretch(),
                                          interval=vis.PercentileInterval(97))

                fig = plt.figure(num="Slit Response from MEF - {}".format(
                    ad.filename),
                                 figsize=(12, 9),
                                 dpi=110)

                gs = gridspec.GridSpec(nrows=2, ncols=3, figure=fig)

                # Display raw mosaicked data and its bins ---
                ax1 = fig.add_subplot(gs[0, 0])
                im1 = ax1.imshow(data,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=norm.vmin,
                                 vmax=norm.vmax)

                ax1.set_title("Mosaicked Data\n and Spectral Bins",
                              fontsize=10)
                ax1.set_xlim(-1, data.shape[1])
                ax1.set_xticks([])
                ax1.set_ylim(-1, data.shape[0])
                ax1.set_yticks(bin_center)
                ax1.tick_params(axis=u'both', which=u'both', length=0)

                ax1.set_yticklabels(
                    ["Bin {}".format(i) for i in range(len(bin_center))],
                    fontsize=6)

                _ = [ax1.spines[s].set_visible(False) for s in ax1.spines]
                _ = [ax1.axhline(b, c='w', lw=0.5) for b in bin_limits]

                divider = make_axes_locatable(ax1)
                cax1 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im1, cax=cax1)

                # Display non-smoothed bins ---
                ax2 = fig.add_subplot(gs[0, 1])
                im2 = ax2.imshow(binned_data, cmap=palette, origin='lower')

                ax2.set_title("Binned, smoothed\n and normalized data ",
                              fontsize=10)
                ax2.set_xlim(0, data.shape[1])
                ax2.set_xticks([])
                ax2.set_ylim(0, data.shape[0])
                ax2.set_yticks(bin_center)
                ax2.tick_params(axis=u'both', which=u'both', length=0)

                ax2.set_yticklabels(
                    ["Bin {}".format(i) for i in range(len(bin_center))],
                    fontsize=6)

                _ = [ax2.spines[s].set_visible(False) for s in ax2.spines]
                _ = [ax2.axhline(b, c='w', lw=0.5) for b in bin_limits]

                divider = make_axes_locatable(ax2)
                cax2 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im2, cax=cax2)

                # Display reconstructed slit response ---
                vmin = slit_response_data.min()
                vmax = slit_response_data.max()

                ax3 = fig.add_subplot(gs[1, 0])
                im3 = ax3.imshow(slit_response_data,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=vmin,
                                 vmax=vmax)

                ax3.set_title("Reconstructed\n Slit response", fontsize=10)
                ax3.set_xlim(0, data.shape[1])
                ax3.set_xticks([])
                ax3.set_ylim(0, data.shape[0])
                ax3.set_yticks([])
                ax3.tick_params(axis=u'both', which=u'both', length=0)
                _ = [ax3.spines[s].set_visible(False) for s in ax3.spines]

                divider = make_axes_locatable(ax3)
                cax3 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im3, cax=cax3)

                # Display extensions ---
                ax4 = fig.add_subplot(gs[1, 1])
                ax4.set_xticks([])
                ax4.set_yticks([])
                _ = [ax4.spines[s].set_visible(False) for s in ax4.spines]

                sub_gs4 = gridspec.GridSpecFromSubplotSpec(nrows=len(ad),
                                                           ncols=1,
                                                           subplot_spec=gs[1,
                                                                           1],
                                                           hspace=0.03)

                # The [::-1] is needed to put the first extension at the bottom
                for i, ext in enumerate(slit_response_ad[::-1]):

                    ext_data, ext_mask, ext_variance = _transpose_if_needed(
                        ext.data,
                        ext.mask,
                        ext.variance,
                        transpose=dispaxis == 1)

                    ext_data = np.ma.masked_array(ext_data, mask=ext_mask)

                    sub_ax = fig.add_subplot(sub_gs4[i])

                    im4 = sub_ax.imshow(ext_data,
                                        origin="lower",
                                        vmin=vmin,
                                        vmax=vmax,
                                        cmap=palette)

                    sub_ax.set_xlim(0, ext_data.shape[1])
                    sub_ax.set_xticks([])
                    sub_ax.set_ylim(0, ext_data.shape[0])
                    sub_ax.set_yticks([ext_data.shape[0] // 2])

                    sub_ax.set_yticklabels(
                        ["Ext {}".format(len(slit_response_ad) - i - 1)],
                        fontsize=6)

                    _ = [
                        sub_ax.spines[s].set_visible(False)
                        for s in sub_ax.spines
                    ]

                    if i == 0:
                        sub_ax.set_title(
                            "Multi-extension\n Slit Response Function")

                divider = make_axes_locatable(ax4)
                cax4 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im4, cax=cax4)

                # Display Signal-To-Noise Ratio ---
                snr = data / np.sqrt(variance)

                norm = vis.ImageNormalize(snr[~snr.mask],
                                          stretch=vis.LinearStretch(),
                                          interval=vis.PercentileInterval(97))

                ax5 = fig.add_subplot(gs[0, 2])

                im5 = ax5.imshow(snr,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=norm.vmin,
                                 vmax=norm.vmax)

                ax5.set_title("Mosaicked Data SNR", fontsize=10)
                ax5.set_xlim(-1, data.shape[1])
                ax5.set_xticks([])
                ax5.set_ylim(-1, data.shape[0])
                ax5.set_yticks(bin_center)
                ax5.tick_params(axis=u'both', which=u'both', length=0)

                ax5.set_yticklabels(
                    ["Bin {}".format(i) for i in range(len(bin_center))],
                    fontsize=6)

                _ = [ax5.spines[s].set_visible(False) for s in ax5.spines]
                _ = [ax5.axhline(b, c='w', lw=0.5) for b in bin_limits]

                divider = make_axes_locatable(ax5)
                cax5 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im5, cax=cax5)

                # Display Signal-To-Noise Ratio of Slit Illumination ---
                slit_response_snr = np.ma.masked_array(
                    slit_response_data / np.sqrt(slit_response_var),
                    mask=slit_response_mask)

                ax6 = fig.add_subplot(gs[1, 2])

                im6 = ax6.imshow(slit_response_snr,
                                 origin="lower",
                                 vmin=norm.vmin,
                                 vmax=norm.vmax,
                                 cmap=palette)

                ax6.set_xlim(0, slit_response_snr.shape[1])
                ax6.set_xticks([])
                ax6.set_ylim(0, slit_response_snr.shape[0])
                ax6.set_yticks([])
                ax6.set_title("Reconstructed\n Slit Response SNR")

                _ = [ax6.spines[s].set_visible(False) for s in ax6.spines]

                divider = make_axes_locatable(ax6)
                cax6 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im6, cax=cax6)

                # Save plots ---
                fig.tight_layout(rect=[0, 0, 0.95, 1], pad=0.5)
                fname = slit_response_ad.filename.replace(".fits", ".png")
                log.info("Saving plots to {}".format(fname))
                plt.savefig(fname)

        return ad_outputs
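A reduced sketch of the 2D reconstruction step above: fit a Chebyshev2D model to values sampled at the bin centres and evaluate it over the full frame. A linear least-squares fitter is used here instead of SLSQPLSQFitter for simplicity, and the data are synthetic.

import numpy as np
from astropy.modeling import models, fitting

height, width = 64, 128
bin_center = np.linspace(8, 56, 7, dtype=int)
cols_fit, rows_fit = np.meshgrid(np.arange(width), bin_center)

# Fake normalized slit profile sampled at the bin centres.
samples = 1.0 - 0.001 * (cols_fit - width / 2) ** 2

model_2d_init = models.Chebyshev2D(x_degree=2, y_degree=2,
                                   x_domain=(0, width), y_domain=(0, height))
fitter = fitting.LinearLSQFitter()
model_2d = fitter(model_2d_init, cols_fit, rows_fit, samples)

rows_val, cols_val = np.mgrid[0:height, 0:width]
slit_response = model_2d(cols_val, rows_val)
print(slit_response.shape)  # (64, 128)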
Example #7
def find_bar_positions_from_image(imagefile, filtersize=5, plot=False,
                                  pixel_shim=5):
    '''Loop over all slits in the image and using the affine transformation
    determined by `fit_transforms`, select the Y pixel range over which this
    slit should be found.  Take a median filtered version of that image and
    determine the X direction gradient (derivative).  Then collapse it in
    the Y direction to form a 1D profile.
    
    Using the `find_bar_edges` method, determine the X pixel positions of
    each bar forming the slit.
    
    Convert those X pixel position to physical coordinates using the
    `pixel_to_physical` method and then call the `compare_to_csu_bar_state`
    method to determine the bar state.
    '''
    ## Get image from file
    imagefile = Path(imagefile).absolute()
    try:
        hdul = fits.open(imagefile)
        data = hdul[0].data
    except OSError as e:
        log.error(e)
        raise
    # median X pixels only (preserve Y structure)
    medimage = ndimage.median_filter(data, size=(1, filtersize))
    
    bars = {}
    ypos = {}
    for slit in range(1,47):
        b1, b2 = slit_to_bars(slit)
        ## Determine y pixel range
        y1 = int(np.ceil((physical_to_pixel(np.array([(4.0, slit+0.5)])))[0][0][1])) + pixel_shim
        y2 = int(np.floor((physical_to_pixel(np.array([(270.4, slit-0.5)])))[0][0][1])) - pixel_shim
        ypos[b1] = [y1, y2]
        ypos[b2] = [y1, y2]
        gradx = np.gradient(medimage[y1:y2,:], axis=1)
        horizontal_profile = np.sum(gradx, axis=0)
        try:
            bars[b1], bars[b2] = find_bar_edges(horizontal_profile)
        except Exception:
            print(f'Unable to fit bars: {b1}, {b2}')

    # Generate plot if called for
    if plot is True:
        plotfile = imagefile.with_name(f"{imagefile.stem}.png")
        log.info(f'Creating PNG image {plotfile}')
        if plotfile.exists(): plotfile.unlink()
        plt.figure(figsize=(16,16), dpi=300)
        norm = viz.ImageNormalize(data, interval=viz.PercentileInterval(99.9),
                                  stretch=viz.LinearStretch())
        plt.imshow(data, norm=norm, origin='lower', cmap='Greys')


        for bar in bars.keys():
#             plt.plot([0,2048], [ypos[bar][0], ypos[bar][0]], 'r-', alpha=0.1)
#             plt.plot([0,2048], [ypos[bar][1], ypos[bar][1]], 'r-', alpha=0.1)

            mms = np.linspace(4,270.4,2)
            slit = bar_to_slit(bar)
            pix = np.array([(physical_to_pixel(np.array([(mm, slit+0.5)])))[0][0] for mm in mms])
            plt.plot(pix.transpose()[0], pix.transpose()[1], 'g-', alpha=0.5)

            plt.plot([bars[bar],bars[bar]], ypos[bar], 'r-', alpha=0.75)
            offset = {0: -20, 1:+20}[bar % 2]
            plt.text(bars[bar]+offset, np.mean(ypos[bar]), bar,
                     fontsize=8, color='r', alpha=0.75,
                     horizontalalignment='center', verticalalignment='center')
        plt.savefig(str(plotfile), bbox_inches='tight')

    return bars
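A synthetic-data sketch of the profile construction described in the docstring: median-filter along X only, take the X gradient, and collapse in Y; the extrema of the 1D profile mark the bar edges.

import numpy as np
from scipy import ndimage

# A bright "slit" between two bar edges at x = 60 and x = 140.
data = np.zeros((50, 200))
data[:, 60:140] = 1000.0
data += np.random.default_rng(3).normal(0., 5., data.shape)

medimage = ndimage.median_filter(data, size=(1, 5))  # median in X, preserve Y
gradx = np.gradient(medimage, axis=1)
horizontal_profile = np.sum(gradx, axis=0)

print(np.argmax(horizontal_profile), np.argmin(horizontal_profile))  # ~60, ~140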
Example #8
def plot_image(image,
               ax=None,
               scale='log',
               cmap=None,
               origin='lower',
               xlabel=None,
               ylabel=None,
               cbar=None,
               clabel='Flux ($e^{-}s^{-1}$)',
               cbar_ticks=None,
               cbar_ticklabels=None,
               cbar_pad=None,
               cbar_size='5%',
               title=None,
               percentile=95.0,
               vmin=None,
               vmax=None,
               offset_axes=None,
               color_bad='k',
               **kwargs):
    """
	Utility function to plot a 2D image.

	Parameters:
		image (2d array): Image data.
		ax (matplotlib.pyplot.axes, optional): Axes in which to plot.
			Default (None) is to use current active axes.
		scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional):
			Normalization used to stretch the colormap.
			Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``, ``'histeq'``, ``'sinh'``
			and ``'squared'``.
			Can also be a :py:class:`astropy.visualization.ImageNormalize` object.
			Default is ``'log'``.
		origin (str, optional): The origin of the coordinate system.
		xlabel (str, optional): Label for the x-axis.
		ylabel (str, optional): Label for the y-axis.
		cbar (string, optional): Location of color bar.
			Choices are ``'right'``, ``'left'``, ``'top'``, ``'bottom'``.
			Default is not to create colorbar.
		clabel (str, optional): Label for the color bar.
		cbar_size (str, optional): Fractional size of colorbar compared to axes. Default='5%'.
		cbar_pad (float, optional): Padding between axes and colorbar.
		title (str or None, optional): Title for the plot.
		percentile (float, optional): The fraction of pixels to keep in color-trim.
			If a single float is given, the same fraction of pixels is eliminated from both ends.
			If a tuple of two floats is given, they are used as the lower and upper percentiles.
			Default=95.
		cmap (matplotlib colormap, optional): Colormap to use. Default is the ``Blues`` colormap.
		vmin (float, optional): Lower limit to use for colormap.
		vmax (float, optional): Upper limit to use for colormap.
		color_bad (str, optional): Color to apply to bad pixels (NaN). Default is black.
		kwargs (dict, optional): Keyword arguments to be passed to :py:func:`matplotlib.pyplot.imshow`.

	Returns:
		:py:class:`matplotlib.image.AxesImage`: Image returned
			by :py:func:`matplotlib.pyplot.imshow`.

	.. codeauthor:: Rasmus Handberg <*****@*****.**>
	"""

    logger = logging.getLogger(__name__)

    # Backward compatible settings:
    make_cbar = kwargs.pop('make_cbar', None)
    if make_cbar:
        # Warn (rather than raise) so the deprecated option keeps working; this
        # assumes the standard-library `warnings` module is imported at module level.
        warnings.warn("'make_cbar' is deprecated. Use 'cbar' instead.",
                      FutureWarning)
        if not cbar:
            cbar = make_cbar

    # Special treatment for boolean arrays:
    if isinstance(image, np.ndarray) and image.dtype == 'bool':
        if vmin is None: vmin = 0
        if vmax is None: vmax = 1
        if cbar_ticks is None: cbar_ticks = [0, 1]
        if cbar_ticklabels is None: cbar_ticklabels = ['False', 'True']

    # Calculate limits of color scaling:
    interval = None
    if vmin is None or vmax is None:
        if allnan(image):
            logger.warning("Image is all NaN")
            vmin = 0
            vmax = 1
            if cbar_ticks is None:
                cbar_ticks = []
            if cbar_ticklabels is None:
                cbar_ticklabels = []
        elif isinstance(percentile, (list, tuple, np.ndarray)):
            interval = viz.AsymmetricPercentileInterval(
                percentile[0], percentile[1])
        else:
            interval = viz.PercentileInterval(percentile)

    # Create ImageNormalize object with extracted limits:
    if scale in ('log', 'linear', 'sqrt', 'asinh', 'histeq', 'sinh',
                 'squared'):
        if scale == 'log':
            stretch = viz.LogStretch()
        elif scale == 'linear':
            stretch = viz.LinearStretch()
        elif scale == 'sqrt':
            stretch = viz.SqrtStretch()
        elif scale == 'asinh':
            stretch = viz.AsinhStretch()
        elif scale == 'histeq':
            stretch = viz.HistEqStretch(image[np.isfinite(image)])
        elif scale == 'sinh':
            stretch = viz.SinhStretch()
        elif scale == 'squared':
            stretch = viz.SquaredStretch()

        # Create ImageNormalize object. Very important to use clip=False if the image contains
        # NaNs, otherwise NaN points will not be plotted correctly.
        norm = viz.ImageNormalize(data=image[np.isfinite(image)],
                                  interval=interval,
                                  vmin=vmin,
                                  vmax=vmax,
                                  stretch=stretch,
                                  clip=not anynan(image))

    elif isinstance(scale, (viz.ImageNormalize, matplotlib.colors.Normalize)):
        norm = scale
    else:
        raise ValueError("scale {} is not available.".format(scale))

    if offset_axes:
        extent = (offset_axes[0] - 0.5, offset_axes[0] + image.shape[1] - 0.5,
                  offset_axes[1] - 0.5, offset_axes[1] + image.shape[0] - 0.5)
    else:
        extent = (-0.5, image.shape[1] - 0.5, -0.5, image.shape[0] - 0.5)

    if ax is None:
        ax = plt.gca()

    # Set up the colormap to use. If a bad color is defined,
    # add it to the colormap:
    if cmap is None:
        cmap = copy.copy(plt.get_cmap('Blues'))
    elif isinstance(cmap, str):
        cmap = copy.copy(plt.get_cmap(cmap))

    if color_bad:
        cmap.set_bad(color_bad, 1.0)

    # Plotting the image using all the settings set above:
    im = ax.imshow(image,
                   cmap=cmap,
                   norm=norm,
                   origin=origin,
                   extent=extent,
                   interpolation='nearest',
                   **kwargs)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    ax.set_xlim([extent[0], extent[1]])
    ax.set_ylim([extent[2], extent[3]])

    if cbar:
        colorbar(im,
                 ax=ax,
                 loc=cbar,
                 size=cbar_size,
                 pad=cbar_pad,
                 label=clabel,
                 ticks=cbar_ticks,
                 ticklabels=cbar_ticklabels)

    # Settings for ticks (each axis needs its own locator instance, since
    # matplotlib locators should not be shared between axes):
    ax.xaxis.set_major_locator(MaxNLocator(nbins=10, integer=True))
    ax.xaxis.set_minor_locator(MaxNLocator(nbins=10, integer=True))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=10, integer=True))
    ax.yaxis.set_minor_locator(MaxNLocator(nbins=10, integer=True))
    ax.tick_params(which='both', direction='out', pad=5)
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()

    return im
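A short usage sketch for plot_image, assuming the function above is importable together with its module-level helpers; the data are synthetic and the colorbar path is not exercised.

import numpy as np
import matplotlib.pyplot as plt

img = np.random.default_rng(4).exponential(size=(100, 100))
img[40:45, 40:45] = np.nan  # bad pixels, drawn in the 'color_bad' colour

fig, ax = plt.subplots()
im = plot_image(img, ax=ax, scale='log', title='Synthetic frame',
                xlabel='Column (pixels)', ylabel='Row (pixels)')
plt.show()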
Example #9
def plot_image_fit_residuals(fig, image, fit, residuals=None, percentile=95.0):
    """
	Make a figure with three subplots showing the image, the fit and the
	residuals. The image and the fit are shown with logarithmic scaling and a
	common colorbar. The residuals are shown with linear scaling and a separate
	colorbar.

	Parameters:
		fig (matplotlib.figure.Figure): Figure object in which to make the subplots.
		image (2D array): Image numpy array.
		fit (2D array): Fitted image numpy array.
		residuals (2D array, optional): Fitted image subtracted from image numpy array.
		percentile (float, optional): Percentile used for the color scaling. Default=95.

	Returns:
		list: List with Matplotlib subplot axes objects for each subplot.
	"""

    if residuals is None:
        residuals = image - fit

    # Calculate common normalization for the first two subplots:
    vmin_image, vmax_image = viz.PercentileInterval(percentile).get_limits(
        image)
    vmin_fit, vmax_fit = viz.PercentileInterval(percentile).get_limits(fit)
    vmin = np.nanmin([vmin_image, vmin_fit])
    vmax = np.nanmax([vmax_image, vmax_fit])
    norm = viz.ImageNormalize(vmin=vmin, vmax=vmax, stretch=viz.LogStretch())

    # Add subplot with the image:
    ax1 = fig.add_subplot(131)
    im1 = plot_image(image, ax=ax1, scale=norm, cbar=None, title='Image')

    # Add subplot with the fit:
    ax2 = fig.add_subplot(132)
    plot_image(fit, ax=ax2, scale=norm, cbar=None, title='PSF fit')

    # Calculate the normalization for the third subplot:
    vmin, vmax = viz.PercentileInterval(percentile).get_limits(residuals)
    v = np.max(np.abs([vmin, vmax]))

    # Add subplot with the residuals:
    ax3 = fig.add_subplot(133)
    im3 = plot_image(residuals,
                     ax=ax3,
                     scale='linear',
                     cmap='seismic',
                     vmin=-v,
                     vmax=v,
                     cbar=None,
                     title='Residuals')

    # Make the common colorbar for image and fit subplots:
    cbar_ax12 = fig.add_axes([0.125, 0.2, 0.494, 0.03])
    fig.colorbar(im1, cax=cbar_ax12, orientation='horizontal')

    # Make the colorbar for the residuals subplot:
    cbar_ax3 = fig.add_axes([0.7, 0.2, 0.205, 0.03])
    fig.colorbar(im3, cax=cbar_ax3, orientation='horizontal')

    # Add more space between subplots:
    plt.subplots_adjust(wspace=0.4, hspace=0.4)

    return [ax1, ax2, ax3]
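A usage sketch for plot_image_fit_residuals, using a synthetic Gaussian "image" and a slightly offset "fit" (assumes the functions above are importable):

import numpy as np
import matplotlib.pyplot as plt

yy, xx = np.mgrid[0:64, 0:64]
image = (np.exp(-((xx - 32.)**2 + (yy - 32.)**2) / 50.)
         + np.random.default_rng(5).normal(0., 0.01, (64, 64)))
fit = np.exp(-((xx - 31.5)**2 + (yy - 32.)**2) / 50.)

fig = plt.figure(figsize=(12, 4))
ax1, ax2, ax3 = plot_image_fit_residuals(fig, image, fit)
plt.show()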
Example #10
 def set_normalization(self,
                       stretch=None,
                       interval=None,
                       stretchkwargs={},
                       intervalkwargs={},
                       perm_linear=None):
     if stretch is None:
         if self.stretch is None:
             stretch = 'linear'
         else:
             stretch = self.stretch
     if isinstance(stretch, str):
         print(stretch,
               ' '.join([f'{k}={v}' for k, v in stretchkwargs.items()]))
         if self.data is None:  #can not calculate objects yet
             self.stretch_kwargs = stretchkwargs
         else:
             kwargs = self.prepare_kwargs(
                 self.stretch_kws_defaults[stretch], self.stretch_kwargs,
                 stretchkwargs)
             if perm_linear is not None:
                 perm_linear_kwargs = self.prepare_kwargs(
                     self.stretch_kws_defaults['linear'], perm_linear)
                 print(
                     'linear', ' '.join([
                         f'{k}={v}' for k, v in perm_linear_kwargs.items()
                     ]))
                 if stretch == 'asinh':  # arg: a=0.1
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.AsinhStretch(**kwargs))
                 elif stretch == 'contrastbias':  # args: contrast, bias
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.ContrastBiasStretch(**kwargs))
                 elif stretch == 'histogram':
                     stretch = vis.CompositeStretch(
                         vis.HistEqStretch(self.data, **kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'log':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LogStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'powerdist':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.PowerDistStretch(**kwargs))
                 elif stretch == 'power':  # args: a
                     stretch = vis.CompositeStretch(
                         vis.PowerStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'sinh':  # args: a=0.33
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SinhStretch(**kwargs))
                 elif stretch == 'sqrt':
                     stretch = vis.CompositeStretch(
                         vis.SqrtStretch(),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'square':
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SquaredStretch())
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
             else:
                 if stretch == 'linear':  # args: slope=1, intercept=0
                     stretch = vis.LinearStretch(**kwargs)
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
     self.stretch = stretch
     if interval is None:
         if self.interval is None:
             interval = 'zscale'
         else:
             interval = self.interval
     if isinstance(interval, str):
         print(interval,
               ' '.join([f'{k}={v}' for k, v in intervalkwargs.items()]))
         kwargs = self.prepare_kwargs(self.interval_kws_defaults[interval],
                                      self.interval_kwargs, intervalkwargs)
         if self.data is None:
             self.interval_kwargs = intervalkwargs
         else:
             if interval == 'minmax':
                 interval = vis.MinMaxInterval()
             elif interval == 'manual':  # args: vmin, vmax
                 interval = vis.ManualInterval(**kwargs)
             elif interval == 'percentile':  # args: percentile, n_samples
                 interval = vis.PercentileInterval(**kwargs)
             elif interval == 'asymetric':  # args: lower_percentile, upper_percentile, n_samples
                 interval = vis.AsymmetricPercentileInterval(**kwargs)
             elif interval == 'zscale':  # args: nsamples=1000, contrast=0.25, max_reject=0.5, min_npixels=5, krej=2.5, max_iterations=5
                 interval = vis.ZScaleInterval(**kwargs)
             else:
                 raise ValueError('Unknown interval:' + interval)
     self.interval = interval
     if self.img is not None:
         self.img.set_norm(
             vis.ImageNormalize(self.data,
                                interval=self.interval,
                                stretch=self.stretch,
                                clip=True))
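A minimal sketch of the CompositeStretch pattern used above, pairing a log stretch with a permanent linear remap inside one ImageNormalize (synthetic data; `vis` is astropy.visualization):

import numpy as np
from astropy import visualization as vis

data = np.random.default_rng(6).lognormal(size=(64, 64))

stretch = vis.CompositeStretch(vis.LogStretch(a=1000.0),
                               vis.LinearStretch(slope=0.9, intercept=0.05))
norm = vis.ImageNormalize(data, interval=vis.ZScaleInterval(),
                          stretch=stretch, clip=True)
print(norm.vmin, norm.vmax)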
Example #11
def rgbfig(
        figfilename="SgrB2N_RGB.pdf",
        lims=[([266.83404223, 266.83172659]), ([-28.373138, -28.3698755])],
        scalebarx=coordinates.SkyCoord(266.833545 * u.deg,
                                       -28.37283819 * u.deg),
        redfn=paths.Fpath('SGRB2N-2012-Q.DePree_K.recentered.fits'),
        greenfn=paths.Fpath(
            'sgr_b2m.N.B3.allspw.continuum.r0.5.clean1000.image.tt0.pbcor.fits'),
        bluefn=paths.Fpath(
            'sgr_b2m.N.B6.allspw.continuum.r0.5.clean1000.image.tt0.pbcor.fits'),
        redpercentile=99.99,
        greenpercentile=99.99,
        bluepercentile=99.99,
        stretch=visualization.AsinhStretch(),
):

    header = fits.getheader(redfn)
    celwcs = wcs.WCS(header).celestial

    redhdu = fits.open(redfn)
    greenhdu = fits.open(greenfn)
    bluehdu = fits.open(bluefn)

    greendata, _ = reproject.reproject_interp(
        (greenhdu[0].data, wcs.WCS(greenhdu[0].header).celestial),
        celwcs,
        shape_out=redhdu[0].data.squeeze().shape)
    bluedata, _ = reproject.reproject_interp(
        (bluehdu[0].data, wcs.WCS(bluehdu[0].header).celestial),
        celwcs,
        shape_out=redhdu[0].data.squeeze().shape)

    #def rescale(x):
    #    return (x-np.nanmin(x))/(np.nanmax(x) - np.nanmin(x))
    redrescale = visualization.PercentileInterval(redpercentile)
    greenrescale = visualization.PercentileInterval(greenpercentile)
    bluerescale = visualization.PercentileInterval(bluepercentile)

    rgb = np.array([
        stretch(redrescale(redhdu[0].data.squeeze())),
        stretch(greenrescale(greendata)),
        stretch(bluerescale(bluedata)),
    ]).swapaxes(0, 2).swapaxes(0, 1)

    norm = visualization.ImageNormalize(
        rgb, interval=visualization.MinMaxInterval(), stretch=stretch)

    fig1 = pl.figure(1)
    fig1.clf()
    ax = fig1.add_subplot(1, 1, 1, projection=celwcs)
    pl.imshow(rgb, origin='lower', interpolation='none', norm=norm)

    (x1, x2), (y1, y2) = celwcs.wcs_world2pix(lims[0], lims[1], 0)
    ax.axis((x1, x2, y1, y2))

    visualization_tools.make_scalebar(ax,
                                      left_side=scalebarx,
                                      length=1.213 * u.arcsec,
                                      label='0.05 pc')

    pl.savefig(paths.fpath(figfilename), bbox_inches='tight')
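A reduced sketch of the RGB-stacking step above: rescale each channel with a PercentileInterval, apply a common stretch, and stack into an (ny, nx, 3) array. Synthetic channels stand in for the reprojected images.

import numpy as np
from astropy import visualization

stretch = visualization.AsinhStretch()
rescale = visualization.PercentileInterval(99.99)

rng = np.random.default_rng(7)
red, green, blue = (rng.lognormal(size=(128, 128)) for _ in range(3))

# Each interval call maps its channel onto [0, 1] before the stretch is applied.
rgb = np.dstack([stretch(rescale(chan)) for chan in (red, green, blue)])
print(rgb.shape)  # (128, 128, 3)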
Example #12
def raft_level_signal_correlations(flat1_files,
                                   flat2_files,
                                   bias_frames,
                                   buffer=10,
                                   title='',
                                   vrange=None,
                                   stretch=viz.LinearStretch,
                                   figsize=(8, 8)):
    """
    Compute the correlation coefficients between the imaging section
    pixels for the difference images from a flat pair for the 144
    amplifiers in a raft.

    Parameters
    ----------
    flat1_files: dict
        Dictionary of flat1 image files, indexed by sensor slot id.
        These should be from the same flat pair frame as the flat2_files.
    flat2_files: dict
        Dictionary of flat2 image files, indexed by sensor slot id.
    bias_frames: dict
        Dictionary of super bias frames, indexed by sensor slot id.
    buffer: int [10]
        Buffer region around perimeter of serial overscan region to
        avoid when computing the correlation coefficients.
    title: str ['']
        Plot title.
    vrange: (float, float) [None]
        Minimum and maximum values for color scale range. If None, then
        the range of the central 98th percentile of the absolute value
        of the data is used.
    stretch: astropy.visualization.BaseStretch [LinearStretch]
        Stretch to use for the color scale.

    Returns
    -------
    (matplotlib.figure.Figure, np.array): The figure containing the plot and
        the numpy array containing the correlation coefficients.
    """
    slots = 'S00 S01 S02 S10 S11 S12 S20 S21 S22'.split()
    segments = []

    ccd0 = sensorTest.MaskedCCD(list(flat1_files.values())[0])
    bbox = ccd0.amp_geom.imaging
    bbox.grow(-buffer)

    for slot in slots:
        if slot not in flat1_files:
            for amp in ccd0:
                segments.append(np.zeros((bbox.getHeight(), bbox.getWidth())))
        else:
            imarrs = diff_image_arrays(flat1_files[slot],
                                       flat2_files[slot],
                                       bias_frame=bias_frames[slot],
                                       buffer=buffer)
            for amp in imarrs:
                segments.append(imarrs[amp])
    namps = len(segments)
    data = np.array([
        np.corrcoef(segments[i[0]].ravel(), segments[i[1]].ravel())[0, 1]
        for i in itertools.product(range(namps), range(namps))
    ])
    data = data.reshape((namps, namps))
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)
    ax.set_title(title, fontsize='medium')

    interval = viz.PercentileInterval(98.)
    if vrange is None:
        vrange = interval.get_limits(np.abs(data.ravel()))
    norm = ImageNormalize(vmin=vrange[0], vmax=vrange[1], stretch=stretch())
    image = ax.imshow(data, interpolation='none', norm=norm)
    plt.colorbar(image)

    set_ticks(ax, slots, amps=16)

    return fig, data
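# A minimal, self-contained sketch of the correlation-matrix computation and
# display used above, with synthetic amplifier segments standing in for the
# CCD data (the sensorTest/diff_image_arrays machinery is not needed here).
import itertools
import numpy as np
import matplotlib.pyplot as plt
from astropy import visualization as viz
from astropy.visualization.mpl_normalize import ImageNormalize

rng = np.random.default_rng(0)
namps = 16
segments = [rng.normal(size=(50, 50)) for _ in range(namps)]

data = np.array([
    np.corrcoef(segments[i].ravel(), segments[j].ravel())[0, 1]
    for i, j in itertools.product(range(namps), range(namps))
]).reshape((namps, namps))

# Color scale spans the central 98th percentile of |data|, as in the function.
vmin, vmax = viz.PercentileInterval(98.).get_limits(np.abs(data.ravel()))
norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=viz.LinearStretch())
fig, ax = plt.subplots()
im = ax.imshow(data, interpolation='none', norm=norm)
fig.colorbar(im)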
def make_hiidust_plot(
    reg,
    mgpsfile,
    width=1 * u.arcmin,
    surveys=['atlasgal'],
    figure=None,
    regname='GAL_031',
    fifth_panel_synchro=False,
    alpha=-0.12,
    cmap=None,
):
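    """
    Make a multi-panel cutout figure around reg: the survey map (e.g.
    ATLASGAL 870 um) scaled to 3 mm assuming dust emission, the inferred
    3 mm free-free map, the MGPS 3 mm map, the 3 mm dust map (or a
    synchrotron estimate if fifth_panel_synchro is set), and the 20 cm map
    scaled to 3 mm assuming optically thin free-free emission with spectral
    index alpha.  Returns a dict of the summed positive intensities
    (MJy/sr) of each component.
    """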

    if cmap is None:
        cmap = pl.cm.viridis
        cmap.set_bad('w')

    mgps_fh = fits.open(mgpsfile)[0]
    frame = wcs.utils.wcs_to_celestial_frame(wcs.WCS(mgps_fh.header))

    coordinate = reg.center
    coordname = "{0:06.3f}_{1:06.3f}".format(coordinate.galactic.l.deg,
                                             coordinate.galactic.b.deg)

    mgps_cutout = Cutout2D(mgps_fh.data,
                           coordinate.transform_to(frame.name),
                           size=width * 2,
                           wcs=wcs.WCS(mgps_fh.header))
    print()
    print(reg.meta['text'])
    print(
        f"Retrieving MAGPIS data for {coordname} ({coordinate.to_string()} {coordinate.frame.name})"
    )
    # we're treating 'width' as a radius elsewhere, here it's a full width
    images = {
        survey: getimg(coordinate, image_size=width * 2, survey=survey)
        for survey in surveys
    }
    images = {x: y for x, y in images.items() if y is not None}
    assert len(images) > 0
    #images['mgps'] = [mgps_cutout]

    # coordinate stuff so images can be reprojected to same frame
    ww = mgps_cutout.wcs.celestial
    mgps_pixscale = (wcs.utils.proj_plane_pixel_area(ww) * u.deg**2)**0.5

    if figure is None:
        figure = pl.gcf()
    figure.clf()

    # exactly one survey image is expected here; the unpacking enforces it
    (survey, img), = images.items()

    new_img = img[0].data
    if hasattr(img[0], 'header'):
        outwcs = wcs.WCS(img[0].header)
    else:
        outwcs = img[0].wcs

    reproj_pixscale = (wcs.utils.proj_plane_pixel_area(outwcs) * u.deg**2)**0.5

    agal_bm = tgt_bm = Beam(beam_map[survey])
    convbm = tgt_bm.deconvolve(mgps_beam)

    mgps_sm = convolution.convolve_fft(mgps_cutout.data,
                                       convbm.as_kernel(mgps_pixscale))
    mgps_reproj, _ = reproject.reproject_interp((mgps_sm, mgps_cutout.wcs),
                                                outwcs,
                                                shape_out=img[0].data.shape)

    mgpsMjysr = mgps_cutout.data / mgps_beam.sr.value / 1e6

    dust_pred = dust_emissivity.blackbody.modified_blackbody(
        u.Quantity(
            [wlmap[survey].to(u.GHz, u.spectral()),
             mustang_central_frequency]),
        assumed_temperature,
        beta=assumed_dustbeta)

    # assumes "surv" is dust
    surv_to_mgps = new_img * dust_pred[1] / dust_pred[0]
    print(f"{regname} {survey}")
    print(f"{survey} to mgps ratio: {dust_pred[1]/dust_pred[0]}")

    dusty = surv_to_mgps.value / tgt_bm.sr.value / 1e6
    freefree = (mgps_reproj / mgps_beam.sr.value / 1e6 - dusty)
    assert not hasattr(freefree, 'unit')
    print("Max values: ", img[0].data.max(), mgps_sm.max())
    print("More max values: ", np.nanmax(dusty), np.nanmax(freefree),
          np.nanmax(mgps_reproj / mgps_beam.sr.value / 1e6))

    norm = visualization.ImageNormalize(
        freefree,
        interval=visualization.ManualInterval(np.nanpercentile(freefree, 0.1),
                                              np.nanpercentile(freefree,
                                                               99.9)),
        stretch=visualization.LogStretch(),
    )
    mgpsnorm = visualization.ImageNormalize(
        mgps_cutout.data,
        interval=visualization.PercentileInterval(99.95),
        stretch=visualization.LogStretch(),
    )
    print(f"interval: {norm.interval.vmin}, {norm.interval.vmax}")
    assert not hasattr(norm.vmin, 'unit')
    assert not hasattr(norm.vmax, 'unit')
    assert not hasattr(mgpsnorm.vmin, 'unit')
    assert not hasattr(mgpsnorm.vmax, 'unit')

    Magpis.cache_location = '/Volumes/external/mgps/cache/'

    ax0 = figure.add_subplot(1, 6, 3, projection=mgps_cutout.wcs)
    ax0.imshow(mgpsMjysr,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax0.set_title("3 mm")
    ax1 = figure.add_subplot(1, 6, 1, projection=outwcs)
    ax1.imshow(dusty,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax1.set_title("870 $\\mu$m scaled")
    ax1.set_ylabel("Galactic Latitude")
    ax2 = figure.add_subplot(1, 6, 2, projection=outwcs)
    ax2.imshow(freefree,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax2.set_title("3 mm Free-Free")

    for ax in (ax0, ax1, ax2):
        #ax.set_xlabel("Galactic Longitude")
        ax.tick_params(direction='in')
        ax.tick_params(color='w')

    ax0.coords[1].set_axislabel("")
    ax0.coords[1].set_ticklabel_visible(False)
    ax2.coords[1].set_axislabel("")
    ax2.coords[1].set_ticklabel_visible(False)

    pl.subplots_adjust(hspace=0, wspace=0)

    if 'G01' in regname:
        gps20im = fits.open('/Users/adam/work/gc/20cm_0.fits', )
    elif 'G49' in regname:
        gps20im = fits.open(
            '/Users/adam/work/w51/vla_old/W51-LBAND-feathered_ABCD.fits')
        #gps20im = fits.open('/Users/adam/work/w51/vla_old/W51-LBAND_Carray.fits')
    else:
        gps20im = getimg(coordinate, image_size=width * 2, survey='gps20new')

    reproj_gps20, _ = reproject.reproject_interp(
        (gps20im[0].data.squeeze(), wcs.WCS(gps20im[0].header).celestial),
        #mgps_fh.header)
        # refactoring to make a smaller cutout would make this faster....
        mgps_cutout.wcs,
        shape_out=mgps_cutout.data.shape)

    gps20cutout = Cutout2D(
        reproj_gps20,  #gps20im[0].data.squeeze(),
        coordinate.transform_to(frame.name),
        size=width * 2,
        wcs=mgps_cutout.wcs)
    #wcs=wcs.WCS(mgps_fh.header))
    #wcs.WCS(gps20im[0].header).celestial)
    ax3 = figure.add_subplot(1, 6, 5, projection=gps20cutout.wcs)

    gps20_bm = Beam.from_fits_header(gps20im[0].header)
    print(f"GPS 20 beam: {gps20_bm.__repr__()}")

    norm20 = visualization.ImageNormalize(
        gps20cutout.data,
        interval=visualization.ManualInterval(
            np.nanpercentile(gps20cutout.data, 0.5),
            np.nanpercentile(gps20cutout.data, 99.9)),
        stretch=visualization.LogStretch(),
    )

    # use 0.12 per Loren's suggestion
    freefree_20cm_to_3mm = (90 * u.GHz / (1.4 * u.GHz))**alpha

    gps20_Mjysr = gps20cutout.data / gps20_bm.sr.value / 1e6

    ax3.imshow((gps20_Mjysr * freefree_20cm_to_3mm).value,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax3.set_title("20 cm scaled")

    ax3.coords[1].set_axislabel("")
    ax3.coords[1].set_ticklabel_visible(False)
    ax3.tick_params(direction='in')
    ax3.tick_params(color='w')

    # Fifth Panel:

    # use freefree_proj to get the 20cm-estimated free-free contribution even
    # if we're not using it for plotting
    # MAGPIS data are high-resolution (comparable to but better than MGPS)
    # Zadeh data are low-resolution, 30ish arcsec
    # units: Jy/sr
    freefree_proj, _ = reproject.reproject_interp(
        (freefree, outwcs), gps20cutout.wcs, shape_out=gps20cutout.data.shape)

    gps20_pixscale = (wcs.utils.proj_plane_pixel_area(gps20cutout.wcs) *
                      u.deg**2)**0.5

    # depending on which image has higher resolution, convolve one to the other
    try:
        gps20convbm = tgt_bm.deconvolve(gps20_bm)
        gps20_Mjysr_sm = convolution.convolve_fft(
            gps20_Mjysr, gps20convbm.as_kernel(gps20_pixscale))
    except ValueError:
        gps20_Mjysr_sm = gps20_Mjysr
        ff_convbm = gps20_bm.deconvolve(tgt_bm)
        freefree_proj = convolution.convolve_fft(
            freefree_proj, ff_convbm.as_kernel(gps20_pixscale))

    if fifth_panel_synchro:

        ax4 = figure.add_subplot(1, 6, 5, projection=gps20cutout.wcs)

        # use the central frequency corresponding to an approximately flat spectrum (flat -> 89.72)
        freefree_3mm_to_20cm = 1 / (90 * u.GHz / (1.4 * u.GHz))**-0.12
        #empirical_factor = 3 # freefree was coming out way too high, don't understand why yet
        synchro = gps20_Mjysr_sm - freefree_proj * freefree_3mm_to_20cm
        synchro[np.isnan(gps20_Mjysr) | (gps20_Mjysr == 0)] = np.nan

        synchroish_ratio = gps20_Mjysr_sm / (freefree_proj *
                                             freefree_3mm_to_20cm)

        #synchro = synchroish_ratio

        normsynchro = visualization.ImageNormalize(
            gps20_Mjysr_sm,
            interval=visualization.ManualInterval(
                np.nanpercentile(gps20_Mjysr_sm, 0.5),
                np.nanpercentile(gps20_Mjysr_sm, 99.9)),
            stretch=visualization.LogStretch(),
        )

        ax4.imshow(synchro.value,
                   origin='lower',
                   interpolation='none',
                   norm=normsynchro,
                   cmap=cmap)
        ax4.set_title("Synchrotron")
        ax4.tick_params(direction='in')
        ax4.tick_params(color='w')
        ax4.coords[1].set_axislabel("")
        ax4.coords[1].set_ticklabel_visible(False)

        pl.tight_layout()
    else:
        # scale 20cm to match MGPS and subtract it

        gps20_pixscale = (wcs.utils.proj_plane_pixel_area(gps20cutout.wcs) *
                          u.deg**2)**0.5

        if gps20_bm.sr < mgps_beam.sr:
            # smooth GPS20 to MGPS
            gps20convbm = mgps_beam.deconvolve(gps20_bm)
            gps20_Mjysr_sm = convolution.convolve_fft(
                gps20_Mjysr, gps20convbm.as_kernel(gps20_pixscale))
            gps20_Mjysr_sm[~np.isfinite(gps20_Mjysr)] = np.nan
            gps20_proj = gps20_Mjysr_sm
            #gps20_proj,_ = reproject.reproject_interp((gps20_Mjysr_sm, gps20cutout.wcs),
            #                                          ww,
            #                                          shape_out=mgps_cutout.data.shape)
        else:
            gps20_proj = gps20_Mjysr
            gps20_convbm = gps20_bm.deconvolve(mgps_beam)
            mgpsMjysr = convolution.convolve_fft(
                mgpsMjysr, gps20_convbm.as_kernel(mgps_pixscale))

        ax4 = figure.add_subplot(1, 6, 4, projection=mgps_cutout.wcs)

        # use the central frequency corresponding to an approximately flat spectrum (flat -> 89.72)
        freefree20 = gps20_proj * freefree_20cm_to_3mm
        dust20 = (mgpsMjysr - freefree20).value
        dust20[np.isnan(gps20_proj) | (gps20_proj == 0)] = np.nan

        normdust20 = visualization.ImageNormalize(
            mgpsMjysr,
            interval=visualization.ManualInterval(
                np.nanpercentile(mgpsMjysr, 0.5),
                np.nanpercentile(mgpsMjysr, 99.9)),
            stretch=visualization.LogStretch(),
        )

        # show smoothed 20 cm
        ax3.imshow((freefree20).value,
                   origin='lower',
                   interpolation='none',
                   norm=norm,
                   cmap=cmap)
        ax4.imshow(dust20,
                   origin='lower',
                   interpolation='none',
                   norm=norm,
                   cmap=cmap)
        ax4.set_title("3 mm Dust")
        ax4.tick_params(direction='in')
        ax4.tick_params(color='w')
        ax4.coords[1].set_axislabel("")
        ax4.coords[1].set_ticklabel_visible(False)

        pl.tight_layout()

    #elif 'G01' not in regname:
    #    norm.vmin = np.min([np.nanpercentile(dust20, 0.5), np.nanpercentile(freefree, 0.1)])
    # NB: dust20, gps20_proj, and freefree20 (used below and in the returned
    # dict) are only defined in the non-synchrotron branch, so everything
    # from here on assumes fifth_panel_synchro=False.
    if np.abs(np.nanpercentile(dust20, 0.5) -
              np.nanpercentile(freefree, 0.1)) < 1e2:
        norm.vmin = np.min(
            [np.nanpercentile(dust20, 0.5),
             np.nanpercentile(freefree, 0.1)])
    if 'arches' in reg.meta['text']:
        norm.vmin = 0.95  # force 1 to be on-scale
    if 'w49b' in reg.meta['text']:
        norm.vmin = np.min(
            [np.nanpercentile(dust20, 8),
             np.nanpercentile(freefree, 0.1)])
        norm.vmin = -4
        norm.vmax = 11

    ax0.imshow(mgpsMjysr,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax1.imshow(dusty,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax2.imshow(freefree,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax3.imshow((gps20_proj * freefree_20cm_to_3mm).value,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)
    ax4.imshow(dust20,
               origin='lower',
               interpolation='none',
               norm=norm,
               cmap=cmap)

    print(
        f"{reg}: dusty sum: {dusty[dusty>0].sum()}   freefreeish sum: {freefree[freefree>0].sum()}"
    )

    area = mgps_reproj.size * (reproj_pixscale**2).to(u.sr)
    mgps_reproj_Mjysr = mgps_reproj / mgps_beam.sr.value / 1e6

    # only label the middle axis: blank all x-axis labels, then restore ax0's
    for ax in figure.axes:
        ax.set_xlabel(" ")

    ax0.set_xlabel("Galactic Longitude")

    lastax = ax3
    bbox = lastax.get_position()

    # this is a painful hack to force the bbox to update
    while bbox.height > 0.9:
        print(f"bbox_height = {bbox.height}")
        pl.pause(0.1)
        bbox = lastax.get_position()

    cax = figure.add_axes([bbox.x1 + 0.01, bbox.y0, 0.02, bbox.height])
    cb = figure.colorbar(mappable=lastax.images[-1], cax=cax)
    cb.set_ticks([-3, 0, 10, 50, 100])
    if 'w51' in reg.meta['text']:
        cb.set_ticks([-10, 0, 20, 200])
    if 'w49b' in reg.meta['text']:
        cb.set_ticks([-3, 0, 3, 10])
    if 'arches' in reg.meta['text']:
        cb.set_ticks([0, 1, 5, 10])
    cb.set_label('MJy sr$^{-1}$')

    return {
        'dust': dusty[dusty > 0].sum(),
        'dust20': dust20[dust20 > 0].sum(),
        'freefree': freefree[freefree > 0].sum(),
        'freefree20': freefree20[freefree20 > 0].sum(),
        'totalpos': mgps_reproj_Mjysr[mgps_reproj_Mjysr > 0].sum(),
        'total': mgps_reproj_Mjysr.sum(),
        'totalpos20': mgpsMjysr[mgpsMjysr > 0].sum(),
        'total20': mgpsMjysr.sum(),
    }
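# A minimal sketch (not from the original source) of the frequency-scaling
# arithmetic used in make_hiidust_plot: optically thin free-free emission is
# extrapolated from 20 cm (1.4 GHz) to 3 mm (90 GHz) with a power law of
# spectral index alpha; the dust side uses the modified-blackbody ratio from
# dust_emissivity instead.
import astropy.units as u

alpha = -0.12
freefree_20cm_to_3mm = (90 * u.GHz / (1.4 * u.GHz))**alpha
print(freefree_20cm_to_3mm)  # ~0.61: predicted 3 mm / 20 cm intensity ratio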
Example #14
0
def timegen(in_path, out_path, m, n, cell, stretch, full_hd):
    """ Generates a timelapse from the input FITS files (directory) and saves it to the given path. \n
        ---------- \n
        parameters \n
        ---------- \n
        in_path  : The path to the directory containing the input FITS files (*.fits or *.fits.fz) \n
        out_path : The path at which the output timelapse will be saved. If unspecified writes to .\\timelapse \n
        m        : Number of rows to split image into \n
        n        : Number of columns to split image into \n
        cell     : The grid cell to choose. Specified by row and column indices. (0,1)
        stretch  : String specifying what stretches to apply on the image
        full_hd : Force the video to be 1920 * 1080 pixels.    
        ---------- \n
        returns \n
        ---------- \n
        True if timelapse generated successfully. \n
    """
    # Step 1: Get FITS files from input path.
    fits_files = get_file_list(in_path, ['fits', 'fz'])

    # Step 1.5: Drop 'background.fits' and 'pointing00.fits' frames from the list.
    fits_files = [
        fname for fname in fits_files if 'background.fits' not in fname
    ]
    fits_files = [
        fname for fname in fits_files if 'pointing00.fits' not in fname
    ]

    # Step 2: Choose the transform to apply based on the requested stretch.
    # TG_LOG_1_PERCENTILE_99 is the intended default, but note that the
    # if/elif/else chain below reassigns `transform` in every branch, so an
    # unrecognized stretch name actually ends up with the sinh transform
    # from the final `else`.
    transform = v.LogStretch(1) + v.PercentileInterval(99)
    if stretch == 'TG_SQRT_PERCENTILE_99':
        transform = v.SqrtStretch() + v.PercentileInterval(99)
    elif stretch == 'TG_LOG_PERCENTILE_99':
        transform = v.LogStretch() + v.PercentileInterval(99)
    elif stretch == 'TG_ASINH_1_PERCENTILE_99':
        transform = v.AsinhStretch(1) + v.PercentileInterval(99)
    elif stretch == 'TG_ASINH_PERCENTILE_99':
        transform = v.AsinhStretch() + v.PercentileInterval(99)
    elif stretch == 'TG_SQUARE_PERCENTILE_99':
        transform = v.SquaredStretch() + v.PercentileInterval(99)
    elif stretch == 'TG_SINH_1_PERCENTILE_99':
        transform = v.SinhStretch(1) + v.PercentileInterval(99)
    else:
        transform = v.SinhStretch() + v.PercentileInterval(99)

    # Step 3:
    for file in tqdm.tqdm(fits_files):
        # Read FITS
        try:
            fits_data = fits.getdata(file)
        except Exception as e:
            # If the current FITS file can't be opened, log and skip it.
            logging.error(str(e))
            continue
        # Flip up down
        flipped_data = np.flipud(fits_data)
        # Debayer with 'RGGB'
        rgb_data = debayer_image_array(flipped_data, pattern='RGGB')
        interested_data = get_sub_image(rgb_data, m, n, cell[0], cell[1])
        # Apply the interval+stretch transform and scale to 8-bit; a BGR copy
        # is made for OpenCV, but note that `bgr_data` is unused below and
        # `save_image` receives the float-valued `interested_data`.
        interested_data = 255 * transform(interested_data)
        rgb_data = interested_data.astype(np.uint8)
        bgr_data = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2BGR)
        # save processed image to temp_dir
        try:
            save_image(interested_data,
                       os.path.split(file)[-1].split('.')[0],
                       path=temp_dir)
        except Exception as e:
            logging.error(str(e))
    # Step 4: Validate output path and create if it doesn't exist.

    # Step 5: Create timelapse from the temporary files
    generate_timelapse_from_images('temp_timelapse', out_path, hd_flag=full_hd)

    # Delete temporary files
    try:
        clear_dir(temp_dir)
    except Exception as e:
        print('Clearing TEMP Files failed. See log for more details')
        logging.error(str(e))
    return True
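# A brief self-contained sketch of the stretch + interval composition used in
# timegen's Step 2: the composed transform applies the percentile interval
# first (rescaling the data onto [0, 1]) and then the stretch.
import numpy as np
from astropy import visualization as v

data = np.random.lognormal(size=(32, 32))
transform = v.LogStretch(1) + v.PercentileInterval(99)
scaled = transform(data)                  # values in [0, 1]
img8 = (255 * scaled).astype(np.uint8)    # 8-bit image, as in Step 3 above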
Example #15
0
def generate_verification_page(lcd, ls, freq, power, cutoutpaths, c_obj,
                               outvppath, outd, show_binned=True):
    """
    Make the verification page, which consists of:

    top row: entire light curve (with horiz bar showing rotation period)

    bottom row:
        lomb scargle periodogram  |  phased light curve  |  image w/ aperture

    ----------
    args:

        lcd (dict): has the light curve, aperture positions, some lomb
        scargle results.

        ls: LombScargle instance with everything passed.

        cutoutpaths (list): FFI cutout FITS paths.

        c_obj (SkyCoord): astropy sky coordinate of the target

        outvppath (str): path to save verification page to
    """
    cutout_wcs = lcd['cutout_wcss'][0]

    mpl.rcParams['xtick.direction'] = 'in'
    mpl.rcParams['ytick.direction'] = 'in'

    plt.close('all')

    fig = plt.figure(figsize=(12,12))

    #ax0 = plt.subplot2grid((3, 3), (0, 0), colspan=3)
    #ax1 = plt.subplot2grid((3, 3), (1, 0), colspan=3)
    #ax2 = plt.subplot2grid((3, 3), (2, 0))
    #ax3 = plt.subplot2grid((3, 3), (2, 1))
    #ax4 = plt.subplot2grid((3, 3), (2, 2), projection=cutout_wcs)

    ax0 = plt.subplot2grid((3, 3), (1, 0), colspan=3)
    ax1 = plt.subplot2grid((3, 3), (2, 0), colspan=3)
    ax2 = plt.subplot2grid((3, 3), (0, 0))
    ax3 = plt.subplot2grid((3, 3), (0, 1))
    ax4 = plt.subplot2grid((3, 3), (0, 2), projection=cutout_wcs)

    #
    # row 0: entire light curve, pre-detrending (with horiz bar showing
    # rotation period). plot model LC too.
    #
    try:
        ax0.scatter(lcd['predetrending_time'], lcd['predetrending_rel_flux'],
                    c='k', alpha=1.0, zorder=3, s=10, rasterized=True,
                    linewidths=0)
    except KeyError as e:
        print('ERR! {}\nReturning.'.format(e))
        return


    try:
        model_flux = nparr(lcd['predetrending_rel_flux']/lcd['rel_flux'])
    except ValueError:
        model_flux = 0

    if isinstance(model_flux, np.ndarray):
        ngroups, groups = find_lc_timegroups(lcd['predetrending_time'], mingap=0.5)
        for group in groups:
            ax0.plot(lcd['predetrending_time'][group], model_flux[group], c='C0',
                     alpha=1.0, zorder=2, rasterized=True, lw=2)

    # add the bar showing the derived period
    ymax = np.percentile(lcd['predetrending_rel_flux'], 95)
    ymin = np.percentile(lcd['predetrending_rel_flux'], 5)
    ydiff = 1.15*(ymax-ymin)

    epoch = np.nanmin(lcd['predetrending_time']) + lcd['ls_period']
    ax0.plot([epoch, epoch+lcd['ls_period']], [ymax, ymax], color='red', lw=2,
             zorder=4)

    ax0.set_ylim((ymin-ydiff,ymax+ydiff))

    #ax0.set_xlabel('Time [BJD$_{\mathrm{TDB}}$]')
    ax0.set_xticklabels('')
    ax0.set_ylabel('Raw flux')

    name = outd['name']
    group_id = outd['group_id']
    if name=='nan':
        nstr = 'Group {}'.format(group_id)
    else:
        nstr = '{}'.format(name)


    if not np.isfinite(outd['teff']):
        outd['teff'] = 0

    ax0.text(0.98, 0.97,
        'Teff={:d}K. {}'.format(int(outd['teff']), nstr),
             ha='right', va='top', fontsize='large', zorder=2,
             transform=ax0.transAxes
    )

    #
    # row 1: entire light curve (with horiz bar showing rotation period)
    #
    ax1.scatter(lcd['time'], lcd['rel_flux'], c='k', alpha=1.0, zorder=2, s=10,
                rasterized=True, linewidths=0)

    # add the bar showing the derived period
    ymax = np.percentile(lcd['rel_flux'], 95)
    ymin = np.percentile(lcd['rel_flux'], 5)
    ydiff = 1.15*(ymax-ymin)

    epoch = np.nanmin(lcd['time']) + lcd['ls_period']
    ax1.plot([epoch, epoch+lcd['ls_period']], [ymax, ymax], color='red', lw=2)

    ax1.set_ylim((ymin-ydiff,ymax+ydiff))

    ax1.set_xlabel(r'Time [BJD$_{\mathrm{TDB}}$]')
    ax1.set_ylabel('Detrended flux')

    #
    # row 2, col 0: lomb scargle periodogram
    #
    ax2.plot(1/freq, power, c='k')
    ax2.set_xscale('log')
    ax2.text(0.03, 0.97, 'FAP={:.1e}\nP={:.1f}d'.format(
        lcd['ls_fap'], lcd['ls_period']), ha='left', va='top',
        fontsize='large', zorder=2, transform=ax2.transAxes
    )
    ax2.set_xlabel('Period [day]', labelpad=-1)
    ax2.set_ylabel('LS power')

    #
    # row 2, col 1: phased light curve 
    #
    phzd = phase_magseries(lcd['time'], lcd['rel_flux'], lcd['ls_period'],
                           lcd['time'][np.argmin(lcd['rel_flux'])], wrap=False,
                           sort=True)

    ax3.scatter(phzd['phase'], phzd['mags'], c='k', rasterized=True, s=10,
                linewidths=0, zorder=1)

    if show_binned:
        try:
            binphasedlc = phase_bin_magseries(phzd['phase'], phzd['mags'],
                                              binsize=1e-2, minbinelems=5)
            binplotphase = binphasedlc['binnedphases']
            binplotmags = binphasedlc['binnedmags']

            ax3.scatter(binplotphase, binplotmags, s=10, c='darkorange',
                        linewidths=0, zorder=3, rasterized=True)
        except TypeError as e:
            print(e)
            pass

    xlim = ax3.get_xlim()
    ax3.hlines(1.0, xlim[0], xlim[1], colors='gray', linestyles='dotted',
               zorder=2)
    ax3.set_xlim(xlim)

    ymax = np.percentile(lcd['rel_flux'], 95)
    ymin = np.percentile(lcd['rel_flux'], 5)
    ydiff = 1.15*(ymax-ymin)
    ax3.set_ylim((ymin-ydiff,ymax+ydiff))

    ax3.set_xlabel('Phase', labelpad=-1)
    ax3.set_ylabel('Flux', labelpad=-0.5)

    #
    # row2, col2: image w/ aperture. put on the nbhr stars as dots too, to
    # ensure the wcs isn't wonky!
    #

    # acquire neighbor stars.
    radius = 2.0*u.arcminute

    nbhr_stars = Catalogs.query_region(
        "{} {}".format(float(c_obj.ra.value), float(c_obj.dec.value)),
        catalog="TIC",
        radius=radius
    )

    try:
        Tmag_cutoff = 15
        px,py = cutout_wcs.all_world2pix(
            nbhr_stars[nbhr_stars['Tmag'] < Tmag_cutoff]['ra'],
            nbhr_stars[nbhr_stars['Tmag'] < Tmag_cutoff]['dec'],
            0
        )
    except Exception as e:
        print('ERR! wcs all_world2pix got {}'.format(repr(e)))
        return

    tmags = nbhr_stars[nbhr_stars['Tmag'] < Tmag_cutoff]['Tmag']

    sel = (px > 0) & (px < 19) & (py > 0) & (py < 19)
    px,py = px[sel], py[sel]
    tmags = tmags[sel]

    ra, dec = float(c_obj.ra.value), float(c_obj.dec.value)
    target_x, target_y = cutout_wcs.all_world2pix(ra,dec,0)

    #
    # finally make it
    #

    img = lcd['median_imgs'][0]

    # some images come out as nans.
    if np.all(np.isnan(img)):
        img = np.ones_like(img)

    interval = vis.PercentileInterval(99.9)
    vmin,vmax = interval.get_limits(img)
    norm = vis.ImageNormalize(
        vmin=vmin, vmax=vmax, stretch=vis.LogStretch(1000))

    cset = ax4.imshow(img, cmap='YlGnBu_r', origin='lower', zorder=1,
                      norm=norm)

    ax4.scatter(px, py, marker='x', c='r', s=5, rasterized=True, zorder=2,
                linewidths=1)
    ax4.plot(target_x, target_y, mew=0.5, zorder=5, markerfacecolor='yellow',
             markersize=7, marker='*', color='k', lw=0)

    #ax4.coords.grid(True, color='white', ls='dotted', lw=1)
    lon = ax4.coords['ra']
    lat = ax4.coords['dec']

    lon.set_ticks(spacing=1*u.arcminute)
    lat.set_ticks(spacing=1*u.arcminute)

    lon.set_ticklabel(exclude_overlapping=True)
    lat.set_ticklabel(exclude_overlapping=True)

    ax4.coords.grid(True, color='white', alpha=0.3, lw=0.3, ls='dotted')

    #cb0 = fig.colorbar(cset, ax=ax4, extend='neither', fraction=0.046, pad=0.04)

    # overplot aperture
    radius_px = 3
    circle = plt.Circle((target_x, target_y), radius_px,
                         color='C1', fill=False, zorder=5)
    ax4.add_artist(circle)

    #
    # cleanup
    # 
    for ax in [ax0,ax1,ax2,ax3,ax4]:
        ax.get_yaxis().set_tick_params(which='both', direction='in',
                                       labelsize='small', top=True, right=True)
        ax.get_xaxis().set_tick_params(which='both', direction='in',
                                       labelsize='small', top=True, right=True)

    fig.tight_layout(w_pad=0.5, h_pad=0)

    #
    # save
    #
    fig.savefig(outvppath, dpi=300, bbox_inches='tight')
    print('made {}'.format(outvppath))
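# Hedged sketch of the phase-folding step that feeds ax3 above, implemented
# with plain numpy in place of the phase_magseries helper used in the function
# (the epoch is the time of minimum flux, matching the call above).
import numpy as np

def fold_lightcurve(time, flux, period, epoch):
    """Phase-fold a light curve on `period` about `epoch`, sorted by phase."""
    phase = ((time - epoch) / period) % 1.0
    order = np.argsort(phase)
    return phase[order], flux[order]

# example: fold a noisy sinusoid on its true period
t = np.linspace(0, 27, 2000)
f = 1 + 0.01 * np.sin(2 * np.pi * t / 3.5) + 0.002 * np.random.randn(t.size)
phase, flux = fold_lightcurve(t, f, period=3.5, epoch=t[np.argmin(f)])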