Example #1
def main():

    filename = get_stack_filename()

    ad = astrodata.open(filename)

    data = ad[0].data
    mask = ad[0].mask
    header = ad[0].hdr

    masked_data = np.ma.masked_where(mask, data, copy=True)

    palette = copy(plt.cm.viridis)
    palette.set_bad('gray')

    norm_factor = visualization.ImageNormalize(
        masked_data,
        stretch=visualization.LinearStretch(),
        interval=visualization.ZScaleInterval(),
    )

    fig, ax = plt.subplots(subplot_kw={'projection': wcs.WCS(header)})

    ax.imshow(masked_data,
              cmap=palette,
              vmin=norm_factor.vmin,
              vmax=norm_factor.vmax)

    ax.set_title(os.path.basename(filename))
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    fig.savefig(filename.replace('.fits', '.png'))
    plt.show()
Example #2
def uvtracks_airydisk2D(tel_tracks, veritas_tels, baselines, airy_func,
                        guess_r, wavelength, save_dir, star_name):
    x_0 = int(np.max(np.abs(tel_tracks)) * 1.2)
    y_0 = int(np.max(np.abs(tel_tracks)) * 1.2)
    airy_disk, airy_funcd = IImodels.airy_disk2D(shape=(x_0, y_0),
                                                 xpos=x_0,
                                                 ypos=y_0,
                                                 angdiam=1.22 * wavelength /
                                                 airy_func.radius.value,
                                                 wavelength=wavelength)
    y, x = np.mgrid[:x_0 * 2, :y_0 * 2]
    airy_disk = airy_funcd(x, y)
    fig = plt.figure(figsize=(18, 12))

    plt.imshow(airy_disk,
               norm=viz.ImageNormalize(airy_disk, stretch=viz.LogStretch()),
               extent=[-x_0, x_0, -y_0, y_0],
               cmap='gray')
    for i, track in enumerate(tel_tracks):
        plt.plot(track[0][:, 0], track[0][:, 1], linewidth=6, color='b')
        # plt.text(track[0][:, 0][5], track[0][:, 1][5], "Baseline %s" % (baselines[i]), fontsize=14, color='w')
        plt.plot(track[1][:, 0], track[1][:, 1], linewidth=6, color='b')
        # plt.text(track[1][:, 0][5], track[1][:, 1][5], "Baseline %s" % (-baselines[i]), fontsize=14, color='w')
    starttime = veritas_tels.time_info.T + veritas_tels.observable_times[0]
    endtime = veritas_tels.time_info.T + veritas_tels.observable_times[-1]
    title = "Coverage of %s at VERITAS \non %s UTC" % (
        star_name, veritas_tels.time_info.T)
    # plt.title(star_name, fontsize=28)

    plt.xlabel("U (m)", fontsize=36)
    plt.ylabel("V (m)", fontsize=36)
    plt.tick_params(axis='both',
                    which='major',
                    labelsize=28,
                    length=10,
                    width=4)
    plt.tick_params(axis='both',
                    which='minor',
                    labelsize=28,
                    length=10,
                    width=4)
    plt.tick_params(which="major", labelsize=24, length=8, width=3)
    plt.tick_params(which="minor", length=6, width=2)
    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=24, length=6, width=3)

    if save_dir:
        graph_saver(
            save_dir,
            "CoverageOf%sOn%sUTC" % (star_name, veritas_tels.time_info.T))
    else:
        plt.show()
Example #3
def norm(image, interval='minmax', stretch='linear'):
    interval_kinds = {
        'zscale': astroviz.ZScaleInterval,
        'minmax': astroviz.MinMaxInterval,
    }
    stretch_kinds = {
        'linear': astroviz.LinearStretch,
        'log': astroviz.LogStretch,
    }
    norm = astroviz.ImageNormalize(image,
                                   interval=interval_kinds[interval](),
                                   stretch=stretch_kinds[stretch]())
    return norm
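A minimal usage sketch for the `norm` helper above, assuming `astroviz` is `astropy.visualization` and matplotlib is available; the image array is a placeholder:

import numpy as np
import matplotlib.pyplot as plt
import astropy.visualization as astroviz

image = np.random.normal(loc=100.0, scale=5.0, size=(256, 256))   # placeholder frame
plt.imshow(image, origin='lower', cmap='gray',
           norm=norm(image, interval='zscale', stretch='log'))
plt.colorbar()
plt.show()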
Example #4
 def get_normalization(self):
     if not isinstance(self.interval, vis.BaseInterval) or not isinstance(
             self.stretch, vis.BaseStretch):
         if self.get_selected_stretch_from_combobox() == 'linear':
             self.set_normalization(self.stretch, self.interval)
         else:
             self.set_normalization(
                 self.stretch,
                 self.interval,
                 perm_linear=self.scale_model.dictionary['linear'])
     return vis.ImageNormalize(self.data,
                               interval=self.interval,
                               stretch=self.stretch,
                               clip=True)
Example #5
def select_cutout(image, wcs):

    #I'm looking for how many pointings are in the mosaic
    #I don't know if it is always accurate
    nrow = image.shape[0] // 300.
    ncol = image.shape[1] // 300.
    #measuring the exact width of a row and of a column
    drow = image.shape[0] / nrow
    dcol = image.shape[1] / ncol
    #I'm showing the image to select the correct section
    #I'm picking the center with a mouse click (maybe)

    fig, ax = plt.subplots(1, 1)
    interval = vis.PercentileInterval(99.9)
    vmin, vmax = interval.get_limits(image)
    norm = vis.ImageNormalize(vmin=vmin,
                              vmax=vmax,
                              stretch=vis.LogStretch(1000))
    ax.imshow(image, cmap=plt.cm.Greys, norm=norm, origin='lower')
    for x in np.arange(0, image.shape[1], dcol):
        ax.axvline(x)
    for y in np.arange(0, image.shape[0], drow):
        ax.axhline(y)

    def onclick(event):

        ix, iy = event.xdata, event.ydata
        col = ix // 300.
        row = iy // 300.
        print(col, row)
        global x_cen, y_cen
        x_cen = 150 + 300 * (col)  # x of the center of the quadrant
        y_cen = 150 + 300 * (row)  # y of the center of the quadrant
        print('x: {:3.0f}, y: {:3.0f}'.format(x_cen, y_cen))
        if event.key == 'q':
            fig.canvas.mpl_disconnect(cid)

    cid = fig.canvas.mpl_connect('button_press_event', onclick)
    plt.show()

    nrow = image.shape[0] // 300.
    ncol = image.shape[1] // 300.
    print(image.shape[0] / nrow)
    x = int(x_cen)
    y = int(y_cen)
    print(x, y)
    cutout = Cutout2D(image, (x, y),
                      size=(image.shape[0] / nrow - 20) * u.pixel,
                      wcs=wcs)
    return cutout
Example #6
File: wcs.py Project: LiuDezi/pyDANDIA
def diagnostic_plots(output_dir, hdu, image_wcs, detected_sources,
                     catalog_sources_xy, matched_stars, offset_x, offset_y):
    """Function to output plots used to assess and debug the astrometry
    performed on the reference image"""

    norm = visualization.ImageNormalize(hdu[0].data, \
                        interval=visualization.ZScaleInterval())

    fig = plt.figure(1)
    plt.imshow(hdu[0].data, origin='lower', cmap=plt.cm.viridis, norm=norm)
    plt.plot(detected_sources[:,1],detected_sources[:,2],'o',markersize=2,\
             markeredgewidth=1,markeredgecolor='b',markerfacecolor='None')
    plt.plot(catalog_sources_xy[:, 0], catalog_sources_xy[:, 1], 'r+')
    plt.xlabel('X pixel')
    plt.ylabel('Y pixel')
    plt.savefig(path.join(output_dir, 'reference_detected_sources_pixels.png'))
    plt.close(1)

    fig = plt.figure(1)
    fig.add_subplot(111, projection=image_wcs)
    plt.imshow(hdu[0].data, origin='lower', cmap=plt.cm.viridis, norm=norm)
    plt.plot(detected_sources[:,1],detected_sources[:,2],'o',markersize=2,\
             markeredgewidth=1,markeredgecolor='b',markerfacecolor='None')
    plt.plot(catalog_sources_xy[:, 0], catalog_sources_xy[:, 1], 'r+')
    plt.xlabel('RA [J2000]')
    plt.ylabel('Dec [J2000]')
    plt.savefig(path.join(output_dir, 'reference_detected_sources_world.png'))
    plt.close(1)

    fig = plt.figure(2)
    plt.subplot(211)
    plt.subplots_adjust(left=0.125,
                        bottom=0.1,
                        right=0.9,
                        top=0.95,
                        wspace=0.1,
                        hspace=0.3)
    plt.hist((matched_stars[:, 5] - matched_stars[:, 2]), 50)
    (xmin, xmax, ymin, ymax) = plt.axis()
    plt.plot([offset_x, offset_x], [ymin, ymax], 'r-')
    plt.xlabel('(Detected-catalog) X pixel')
    plt.ylabel('Frequency')
    plt.subplot(212)
    plt.hist((matched_stars[:, 4] - matched_stars[:, 1]), 50)
    (xmin, xmax, ymin, ymax) = plt.axis()
    plt.plot([offset_y, offset_y], [ymin, ymax], 'r-')
    plt.xlabel('(Detected-catalog) Y pixel')
    plt.ylabel('Frequency')
    plt.savefig(path.join(output_dir, 'astrometry_separations.png'))
    plt.close(2)
Example #7
def main():
    args = _parse_args()
    filename = args.filename

    ad = astrodata.open(filename)

    data = ad[0].data
    mask = ad[0].mask
    header = ad[0].hdr

    if args.mask:
        masked_data = np.ma.masked_where(mask, data, copy=True)
    else:
        masked_data = data

    palette = copy(plt.cm.viridis)
    palette.set_bad('Gainsboro')

    norm_factor = visualization.ImageNormalize(
        masked_data,
        stretch=visualization.LinearStretch(),
        interval=visualization.ZScaleInterval(),
    )

    fig = plt.figure(num=filename)
    ax = fig.subplots(subplot_kw={"projection": wcs.WCS(header)})

    print(norm_factor.vmin)
    print(norm_factor.vmax)
    ax.imshow(
        masked_data,
        cmap=palette,
        #vmin=norm_factor.vmin,
        #vmax=norm_factor.vmax,
        vmin=750.,
        vmax=900.,
        origin='lower')

    ax.set_title(os.path.basename(filename))

    ax.coords[0].set_axislabel('Right Ascension')
    ax.coords[0].set_ticklabel(fontsize='small')

    ax.coords[1].set_axislabel('Declination')
    ax.coords[1].set_ticklabel(rotation='vertical', fontsize='small')

    fig.tight_layout(rect=[0.05, 0, 1, 1])
    fig.savefig(os.path.basename(filename.replace('.fits', '.png')))
    plt.show()
Example #8
def build_psf_mask(setup, psf_size, diagnostics=False):
    """Function to construct a mask for the PSF of a single star, which
    is a 2D image array with 1.0 at all pixel locations within the PSF and 
    zero everywhere outside it."""

    half_psf = int(psf_size)
    half_psf2 = half_psf * half_psf

    pxmax = 2 * half_psf + 1
    pymax = pxmax

    psf_mask = np.ones([pymax, pxmax])

    pxmin = -half_psf
    pxmax = half_psf + 1
    pymin = -half_psf
    pymax = half_psf + 1

    for dx in range(pxmin, pxmax, 1):

        dx2 = dx * dx

        for dy in range(pymin, pymax, 1):

            if (dx2 + dy * dy) > half_psf2:
                psf_mask[dx + half_psf, dy + half_psf] = 0.0

    if diagnostics == True:
        fig = plt.figure(1)

        norm = visualization.ImageNormalize(psf_mask, \
                            interval=visualization.ZScaleInterval())

        plt.imshow(psf_mask, origin='lower', cmap=plt.cm.viridis, norm=norm)

        plt.xlabel('X pixel')

        plt.ylabel('Y pixel')

        plt.savefig(path.join(setup.red_dir, 'psf_mask.png'))

        plt.close(1)

    return psf_mask
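A quick sanity-check sketch for `build_psf_mask` above; `setup` is only used when diagnostics=True, so None is passed here (illustrative values):

psf_mask = build_psf_mask(None, 8.0, diagnostics=False)
print(psf_mask.shape)            # (17, 17) for psf_size = 8
print(int(psf_mask.sum()))       # number of pixels inside the circular PSF region
assert psf_mask[8, 8] == 1.0     # the centre of the PSF is inside the mask
assert psf_mask[0, 0] == 0.0     # the corners fall outside the PSF radius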
Example #9
def main():

    filename = get_stack_filename()

    ad = astrodata.open(filename)

    fig = plt.figure(num=filename, figsize=(7, 4.5))
    fig.suptitle(os.path.basename(filename), y=0.97)

    axs = fig.subplots(1, len(ad), sharey=True)

    palette = copy(plt.cm.viridis)
    palette.set_bad("Gainsboro", 1.0)

    norm = visualization.ImageNormalize(
        np.dstack([ext.data for ext in ad]),
        stretch=visualization.LinearStretch(),
        interval=visualization.ZScaleInterval()
    )

    print(norm.vmin)
    print(norm.vmax)
    for i in range(len(ad)):

        axs[i].imshow(
            # np.ma.masked_where(ad[i].mask > 0, ad[i].data),
            ad[i].data,
            #norm=colors.Normalize(vmin=norm.vmin, vmax=norm.vmax),
            norm=colors.Normalize(vmin=750, vmax=900),
            origin="lower",
            cmap=palette,
        )

        axs[i].set_xlabel('d{:02d}'.format(i+1))
        axs[i].set_xticks([])

    axs[i].set_yticks([])

    fig.tight_layout(rect=[0, 0, 1, 1], w_pad=0.05)

    fig.savefig(os.path.basename(filename.replace('.fits', '.png')))
    plt.show()
Example #10
    def normalizeImage(cls, image):
        """Normalizes the image data to the [0,1] domain, using histogram
        equalization.

        Parameters
        ----------
        image : `np.array`
            Image.

        Returns
        -------
        norm : `np.array`
            Normalized image.
        """
        # TODO: make things like these configurable (also see resize in
        # store_thumbnail)
        stretch = aviz.HistEqStretch(image)
        norm = aviz.ImageNormalize(image, stretch=stretch, clip=True)

        return norm(image)
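For reference, a standalone sketch of the same histogram-equalization normalization outside the class context, assuming `aviz` is `astropy.visualization` and using placeholder data:

import numpy as np
import astropy.visualization as aviz

image = np.random.lognormal(mean=3.0, sigma=1.0, size=(128, 128))   # placeholder data
stretch = aviz.HistEqStretch(image)                # stretch built from the image's own histogram
norm = aviz.ImageNormalize(image, stretch=stretch, clip=True)
scaled = norm(image)                               # values mapped into [0, 1]
print(scaled.min(), scaled.max())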
Example #11
def display_airy_disk(veritas_array, angd, wavelength, save_dir):
    airy_disk, airy_func = IImodels.airy_disk2D(shape=(veritas_array.xlen,
                                                       veritas_array.ylen),
                                                xpos=veritas_array.xlen / 2,
                                                ypos=veritas_array.ylen / 2,
                                                angdiam=angd,
                                                wavelength=wavelength)
    # norm = viz.ImageNormalize(1, stretch=viz.LogStretch())

    plt.figure(figsize=(80, 80))

    plt.title("The Airy disk of a %s Point Source" % (angd))
    plt.xlabel("Meters")
    plt.ylabel("Meters")
    norm = viz.ImageNormalize(airy_disk, stretch=viz.SqrtStretch())

    plt.imshow(airy_disk, norm=norm, cmap="gray")

    if save_dir:
        graph_saver(save_dir, "AiryDisk")
    else:
        plt.show()
Example #12
def test_find_psf_companion_stars():
    """Function to test the identification of stars that neighbour a PSF star 
    from the reference catalogue."""

    setup = pipeline_setup.pipeline_setup({'red_dir': TEST_DIR})

    log = logs.start_stage_log(cwd, 'test_find_psf_companions')

    log.info(setup.summary())

    reduction_metadata = metadata.MetaData()
    reduction_metadata.load_a_layer_from_file(setup.red_dir,
                                              'pyDANDIA_metadata.fits',
                                              'reduction_parameters')

    star_catalog_file = os.path.join(TEST_DATA, 'star_catalog.fits')

    ref_star_catalog = catalog_utils.read_ref_star_catalog_file(
        star_catalog_file)

    log.info('Read in catalog of ' + str(len(ref_star_catalog)) + ' stars')

    psf_idx = 18
    psf_x = 189.283172607
    psf_y = 9.99084472656
    psf_size = 8.0

    stamp_dims = (20, 20)

    comps_list = psf.find_psf_companion_stars(setup, psf_idx, psf_x, psf_y,
                                              psf_size, ref_star_catalog, log,
                                              stamp_dims)

    assert len(comps_list) > 0

    for l in comps_list:
        log.info(repr(l))

    image_file = os.path.join(TEST_DATA,
                              'lsc1m005-fl15-20170701-0144-e91_cropped.fits')

    image = fits.getdata(image_file)

    corners = psf.calc_stamp_corners(psf_x,
                                     psf_y,
                                     stamp_dims[1],
                                     stamp_dims[0],
                                     image.shape[1],
                                     image.shape[0],
                                     over_edge=True)

    stamp = image[corners[2]:corners[3], corners[0]:corners[1]]

    log.info('Extracting PSF stamp image')

    fig = plt.figure(1)

    norm = visualization.ImageNormalize(stamp, \
                interval=visualization.ZScaleInterval())

    plt.imshow(stamp, origin='lower', cmap=plt.cm.viridis, norm=norm)

    x = []
    y = []
    for j in range(0, len(comps_list), 1):
        x.append(comps_list[j][1])
        y.append(comps_list[j][2])

    plt.plot(x, y, 'r+')

    plt.axis('equal')

    plt.xlabel('X pixel')

    plt.ylabel('Y pixel')

    plt.savefig(os.path.join(TEST_DATA, 'psf_companion_stars.png'))

    plt.close(1)

    logs.close_log(log)
Example #13
def test_subtract_companions_from_psf_stamps():
    """Function to test the function which removes companion stars from the 
    surrounds of a PSF star in a PSF star stamp image."""

    setup = pipeline_setup.pipeline_setup({'red_dir': TEST_DIR})

    log = logs.start_stage_log(cwd, 'test_subtract_companions')

    log.info(setup.summary())

    reduction_metadata = metadata.MetaData()
    reduction_metadata.load_a_layer_from_file(setup.red_dir,
                                              'pyDANDIA_metadata.fits',
                                              'reduction_parameters')

    star_catalog_file = os.path.join(TEST_DATA, 'star_catalog.fits')

    ref_star_catalog = catalog_utils.read_ref_star_catalog_file(
        star_catalog_file)

    log.info('Read in catalog of ' + str(len(ref_star_catalog)) + ' stars')

    image_file = os.path.join(TEST_DATA,
                              'lsc1m005-fl15-20170701-0144-e91_cropped.fits')

    image = fits.getdata(image_file)

    psf_idx = [248]
    psf_x = 257.656
    psf_y = 121.365

    stamp_centres = np.array([[psf_x, psf_y]])

    psf_size = 10.0
    stamp_dims = (20, 20)

    stamps = psf.cut_image_stamps(setup, image, stamp_centres, stamp_dims)

    if len(stamps) == 0:

        log.info(
            'ERROR: No PSF stamp images returned.  PSF stars too close to the edge?'
        )

    else:

        for i, s in enumerate(stamps):

            fig = plt.figure(1)

            norm = visualization.ImageNormalize(s.data, \
                            interval=visualization.ZScaleInterval())

            plt.imshow(s.data, origin='lower', cmap=plt.cm.viridis, norm=norm)

            plt.xlabel('X pixel')

            plt.ylabel('Y pixel')

            plt.axis('equal')

            plt.savefig(
                os.path.join(setup.red_dir,
                             'psf_star_stamp' + str(i) + '.png'))

            plt.close(1)

        psf_model = psf.Moffat2D()
        x_cen = psf_size + (psf_x - int(psf_x))
        y_cen = psf_size + (psf_y - int(psf_y))
        psf_radius = 8.0
        psf_params = [
            103301.241291, x_cen, y_cen, 226.750731765, 13004.8930993,
            103323.763627
        ]
        psf_model.update_psf_parameters(psf_params)

        sky_model = psf.ConstantBackground()
        sky_model.background_parameters.constant = 1345.0

        clean_stamps = psf.subtract_companions_from_psf_stamps(
            setup,
            reduction_metadata,
            log,
            ref_star_catalog,
            psf_idx,
            stamps,
            stamp_centres,
            psf_model,
            sky_model,
            diagnostics=True)

    logs.close_log(log)
Example #14
def show_image(image,
               percl=99,
               percu=None,
               is_mask=False,
               figsize=(6, 10),
               cmap='viridis',
               log=False,
               show_colorbar=True,
               show_ticks=True,
               fig=None,
               ax=None,
               input_ratio=None):
    """
    Show an image in matplotlib with some basic astronomically-appropriate stretching.

    Parameters
    ----------
    image
        The image to show
    percl : number
        The percentile for the lower edge of the stretch (or both edges if ``percu`` is None)
    percu : number or None
        The percentile for the upper edge of the stretch (or None to use ``percl`` for both)
    figsize : 2-tuple
        The size of the matplotlib figure in inches
    """
    if percu is None:
        percu = percl
        percl = 100 - percl

    if (fig is None and ax is not None) or (fig is not None and ax is None):
        raise ValueError('Must provide both "fig" and "ax" '
                         'if you provide one of them')
    elif fig is None and ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        if figsize is not None:
            # Rescale the fig size to match the image dimensions, roughly
            image_aspect_ratio = image.shape[0] / image.shape[1]
            figsize = (max(figsize) * image_aspect_ratio, max(figsize))
            print(figsize)

    # To preserve details we should *really* downsample correctly and
    # not rely on matplotlib to do it correctly for us (it won't).

    # So, calculate the size of the figure in pixels, block_reduce to
    # roughly that, and display the block-reduced image.

    # Thanks, https://stackoverflow.com/questions/29702424/how-to-get-matplotlib-figure-size
    fig_size_pix = fig.get_size_inches() * fig.dpi

    ratio = (image.shape // fig_size_pix).max()

    if ratio < 1:
        ratio = 1

    ratio = input_ratio or ratio

    # Divide by the square of the ratio to keep the flux the same in the
    # reduced image
    reduced_data = block_reduce(image, ratio) / ratio**2

    # Of course, now that we have downsampled, the axis limits are changed to
    # match the smaller image size. Setting the extent will do the trick to
    # change the axis display back to showing the actual extent of the image.
    extent = [0, image.shape[1], 0, image.shape[0]]

    if log:
        stretch = aviz.LogStretch()
    else:
        stretch = aviz.LinearStretch()

    norm = aviz.ImageNormalize(reduced_data,
                               interval=aviz.AsymmetricPercentileInterval(
                                   percl, percu),
                               stretch=stretch)

    if is_mask:
        # The image is a mask in which pixels are zero or one. Set the image scale
        # limits appropriately.
        scale_args = dict(vmin=0, vmax=1)
    else:
        scale_args = dict(norm=norm)

    im = ax.imshow(reduced_data,
                   origin='lower',
                   cmap=cmap,
                   extent=extent,
                   aspect='equal',
                   **scale_args)

    if show_colorbar:
        # I haven't a clue why the fraction and pad arguments below work to make
        # the colorbar the same height as the image, but they do....unless the image
        # is wider than it is tall. Sticking with this for now anyway...
        # Thanks: https://stackoverflow.com/a/26720422/3486425
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, format='%2.0f')
        # In case someone in the future wants to improve this:
        # https://joseph-long.com/writing/colorbars/
        # https://stackoverflow.com/a/33505522/3486425
        # https://matplotlib.org/mpl_toolkits/axes_grid/users/overview.html#colorbar-whose-height-or-width-in-sync-with-the-master-axes

    if not show_ticks:
        ax.tick_params(labelbottom=False,
                       labelleft=False,
                       labelright=False,
                       labeltop=False)
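A minimal call sketch for `show_image` above, assuming the module-level imports it relies on (`matplotlib.pyplot as plt`, `astropy.visualization as aviz`, and `block_reduce` from `astropy.nddata`); the data is synthetic:

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(seed=0)
image = rng.normal(loc=1000.0, scale=20.0, size=(512, 512))   # placeholder sky frame
image[200:210, 300:310] += 5000.0                             # a fake bright source

show_image(image, percl=99, log=True, cmap='viridis')
plt.show()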
Example #15
def find_bar_positions_from_image(imagefile, filtersize=5, plot=False,
                                  pixel_shim=5):
    '''Loop over all slits in the image and using the affine transformation
    determined by `fit_transforms`, select the Y pixel range over which this
    slit should be found.  Take a median filtered version of that image and
    determine the X direction gradient (derivative).  Then collapse it in
    the Y direction to form a 1D profile.
    
    Using the `find_bar_edges` method, determine the X pixel positions of
    each bar forming the slit.
    
    Convert those X pixel position to physical coordinates using the
    `pixel_to_physical` method and then call the `compare_to_csu_bar_state`
    method to determine the bar state.
    '''
    ## Get image from file
    imagefile = Path(imagefile).absolute()
    try:
        hdul = fits.open(imagefile)
        data = hdul[0].data
    except Exception as e:
        log.error(e)
        raise
    # median X pixels only (preserve Y structure)
    medimage = ndimage.median_filter(data, size=(1, filtersize))
    
    bars = {}
    ypos = {}
    for slit in range(1,47):
        b1, b2 = slit_to_bars(slit)
        ## Determine y pixel range
        y1 = int(np.ceil((physical_to_pixel(np.array([(4.0, slit+0.5)])))[0][0][1])) + pixel_shim
        y2 = int(np.floor((physical_to_pixel(np.array([(270.4, slit-0.5)])))[0][0][1])) - pixel_shim
        ypos[b1] = [y1, y2]
        ypos[b2] = [y1, y2]
        gradx = np.gradient(medimage[y1:y2,:], axis=1)
        horizontal_profile = np.sum(gradx, axis=0)
        try:
            bars[b1], bars[b2] = find_bar_edges(horizontal_profile)
        except:
            print(f'Unable to fit bars: {b1}, {b2}')

    # Generate plot if called for
    if plot is True:
        plotfile = imagefile.with_name(f"{imagefile.stem}.png")
        log.info(f'Creating PNG image {plotfile}')
        if plotfile.exists(): plotfile.unlink()
        plt.figure(figsize=(16,16), dpi=300)
        norm = viz.ImageNormalize(data, interval=viz.PercentileInterval(99.9),
                                  stretch=viz.LinearStretch())
        plt.imshow(data, norm=norm, origin='lower', cmap='Greys')


        for bar in bars.keys():
#             plt.plot([0,2048], [ypos[bar][0], ypos[bar][0]], 'r-', alpha=0.1)
#             plt.plot([0,2048], [ypos[bar][1], ypos[bar][1]], 'r-', alpha=0.1)

            mms = np.linspace(4,270.4,2)
            slit = bar_to_slit(bar)
            pix = np.array([(physical_to_pixel(np.array([(mm, slit+0.5)])))[0][0] for mm in mms])
            plt.plot(pix.transpose()[0], pix.transpose()[1], 'g-', alpha=0.5)

            plt.plot([bars[bar],bars[bar]], ypos[bar], 'r-', alpha=0.75)
            offset = {0: -20, 1:+20}[bar % 2]
            plt.text(bars[bar]+offset, np.mean(ypos[bar]), bar,
                     fontsize=8, color='r', alpha=0.75,
                     horizontalalignment='center', verticalalignment='center')
        plt.savefig(str(plotfile), bbox_inches='tight')

    return bars
Example #16
rslts_thresh = uvcombine.feather_plot(almafn,
                                      lores=loresfn,
                                      lowresfwhm=loresfwhm,
                                      hires_threshold=0.0005,
                                      lores_threshold=0.001)
pl.figure(2).clf()
pl.imshow(combined.real + 0.01,
          origin='lower',
          interpolation='none',
          vmax=0.1,
          vmin=0.001,
          norm=pl.matplotlib.colors.LogNorm())
pl.axis((1620, 2300, 1842, 2750))

pl.figure(3).clf()
asinhnorm = lambda: visualization.ImageNormalize(stretch=visualization.AsinhStretch())

ax1 = pl.subplot(2, 3, 1)
im1 = ax1.imshow(combined.real + 0.01,
                 origin='lower',
                 interpolation='none',
                 vmax=0.1,
                 vmin=0.001,
                 norm=asinhnorm())
ax1.axis((1620, 2300, 1842, 2750))
pl.colorbar(mappable=im1)

ax2 = pl.subplot(2, 3, 2)
im2 = ax2.imshow(almafh.data,
                 origin='lower',
                 interpolation='none',
Example #17
def make_sed_plot(coordinate,
                  mgpsfile,
                  width=1 * u.arcmin,
                  surveys=Magpis.list_surveys(),
                  figure=None,
                  regname='GAL_031'):

    mgps_fh = fits.open(mgpsfile)[0]
    frame = wcs.utils.wcs_to_celestial_frame(wcs.WCS(mgps_fh.header))

    coordname = "{0:06.3f}_{1:06.3f}".format(coordinate.galactic.l.deg,
                                             coordinate.galactic.b.deg)

    mgps_cutout = Cutout2D(mgps_fh.data,
                           coordinate.transform_to(frame.name),
                           size=width * 2,
                           wcs=wcs.WCS(mgps_fh.header))
    print(
        f"Retrieving MAGPIS data for {coordname} ({coordinate.to_string()} {coordinate.frame.name})"
    )
    # we're treating 'width' as a radius elsewhere, here it's a full width
    images = {
        survey: getimg(coordinate, image_size=width * 2.75, survey=survey)
        for survey in surveys
    }
    images = {x: y for x, y in images.items() if y is not None}
    images['mgps'] = [mgps_cutout]

    regdir = os.path.join(paths.basepath, regname)
    if not os.path.exists(regdir):
        os.mkdir(regdir)
    higaldir = os.path.join(paths.basepath, regname, 'HiGalCutouts')
    if not os.path.exists(higaldir):
        os.mkdir(higaldir)
    if not any([
            os.path.exists(f"{higaldir}/{coordname}_{wavelength}.fits")
            for wavelength in map(int, HiGal.HIGAL_WAVELENGTHS.values())
    ]):
        print(
            f"Retrieving HiGal data for {coordname} ({coordinate.to_string()} {coordinate.frame.name})"
        )
        higal_ims = HiGal.get_images(coordinate, radius=width * 1.5)
        for hgim in higal_ims:
            images['HiGal{0}'.format(hgim[0].header['WAVELEN'])] = hgim
            hgim.writeto(
                f"{higaldir}/{coordname}_{hgim[0].header['WAVELEN']}.fits")
    else:
        print(
            f"Loading HiGal data from disk for {coordname} ({coordinate.to_string()} {coordinate.frame.name})"
        )
        for wavelength in map(int, HiGal.HIGAL_WAVELENGTHS.values()):
            hgfn = f"{higaldir}/{coordname}_{wavelength}.fits"
            if os.path.exists(hgfn):
                hgim = fits.open(hgfn)
                images['HiGal{0}'.format(hgim[0].header['WAVELEN'])] = hgim

    if 'gpsmsx2' in images:
        # redundant, save some space for a SED plot
        del images['gpsmsx2']
    if 'gps90' in images:
        # too low-res to be useful
        del images['gps90']

    if figure is None:
        figure = pl.figure(figsize=(15, 12))

    # coordinate stuff so images can be reprojected to same frame
    ww = mgps_cutout.wcs.celestial
    target_header = ww.to_header()
    del target_header['LONPOLE']
    del target_header['LATPOLE']
    mgps_pixscale = (wcs.utils.proj_plane_pixel_area(ww) * u.deg**2)**0.5
    target_header['NAXES'] = 2
    target_header['NAXIS1'] = target_header['NAXIS2'] = (
        width / mgps_pixscale).decompose().value
    #shape = [int((width / mgps_pixscale).decompose().value)]*2
    outframe = wcs.utils.wcs_to_celestial_frame(ww)
    crd_outframe = coordinate.transform_to(outframe)

    figure.clf()

    imagelist = sorted(images.items(), key=lambda x: wlmap[x[0]])

    #for ii, (survey,img) in enumerate(images.items()):
    for ii, (survey, img) in enumerate(imagelist):

        if hasattr(img[0], 'header'):
            inwcs = wcs.WCS(img[0].header).celestial
            pixscale_in = (wcs.utils.proj_plane_pixel_area(inwcs) *
                           u.deg**2)**0.5

            target_header['CDELT1'] = -pixscale_in.value
            target_header['CDELT2'] = pixscale_in.value
            target_header['CRVAL1'] = crd_outframe.spherical.lon.deg
            target_header['CRVAL2'] = crd_outframe.spherical.lat.deg
            axsize = int((width * 2.5 / pixscale_in).decompose().value)
            target_header['NAXIS1'] = target_header['NAXIS2'] = axsize
            target_header['CRPIX1'] = target_header['NAXIS1'] / 2
            target_header['CRPIX2'] = target_header['NAXIS2'] / 2
            shape_out = [axsize, axsize]

            print(
                f"Reprojecting {survey} to scale {pixscale_in} with shape {shape_out} and center {crd_outframe.to_string()}"
            )

            outwcs = wcs.WCS(target_header)

            new_img, _ = reproject.reproject_interp((img[0].data, inwcs),
                                                    target_header,
                                                    shape_out=shape_out)
        else:
            new_img = img[0].data
            outwcs = img[0].wcs
            pixscale_in = (wcs.utils.proj_plane_pixel_area(outwcs) *
                           u.deg**2)**0.5

        ax = figure.add_subplot(4, 5, ii + 1, projection=outwcs)
        ax.set_title("{0}: {1}".format(survey_titles[survey], wlmap[survey]))

        if not np.any(np.isfinite(new_img)):
            print(f"SKIPPING {survey}")
            continue

        norm = visualization.ImageNormalize(
            new_img,
            interval=visualization.PercentileInterval(99.95),
            stretch=visualization.AsinhStretch(),
        )

        ax.imshow(new_img, origin='lower', interpolation='none', norm=norm)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.xaxis.set_ticklabels('')
        ax.yaxis.set_ticklabels('')
        ax.coords[0].set_ticklabel_visible(False)
        ax.coords[1].set_ticklabel_visible(False)

        if 'GLON' in outwcs.wcs.ctype[0]:
            xpix, ypix = outwcs.wcs_world2pix(coordinate.galactic.l,
                                              coordinate.galactic.b, 0)
        else:
            xpix, ypix = outwcs.wcs_world2pix(coordinate.fk5.ra,
                                              coordinate.fk5.dec, 0)
        ax.set_xlim(xpix - (width / pixscale_in), xpix + (width / pixscale_in))
        ax.set_ylim(ypix - (width / pixscale_in), ypix + (width / pixscale_in))

        # scalebar = 1 arcmin

        ax.plot([
            xpix - width / pixscale_in + 5 * u.arcsec / pixscale_in,
            xpix - width / pixscale_in + 65 * u.arcsec / pixscale_in
        ], [
            ypix - width / pixscale_in + 5 * u.arcsec / pixscale_in,
            ypix - width / pixscale_in + 5 * u.arcsec / pixscale_in
        ],
                linestyle='-',
                linewidth=1,
                color='w')
        ax.plot(crd_outframe.spherical.lon.deg,
                crd_outframe.spherical.lat.deg,
                marker=((0, -10), (0, -4)),
                color='w',
                linestyle='none',
                markersize=20,
                markeredgewidth=0.5,
                transform=ax.get_transform('world'))
        ax.plot(crd_outframe.spherical.lon.deg,
                crd_outframe.spherical.lat.deg,
                marker=((4, 0), (10, 0)),
                color='w',
                linestyle='none',
                markersize=20,
                markeredgewidth=0.5,
                transform=ax.get_transform('world'))

    pl.tight_layout()
Example #18
            cubes['mask'].with_mask(include_mask).minimal_subcube()[0]
            if crop else cubes['mask'].with_mask(include_mask)[0])
    except AssertionError:
        # this implies there is no mask
        pass

    imgs['includemask'] = include_mask  # the mask applied to the cube

    # give up on the 'Slice' nature so we can change units
    imgs['model'] = imgs['model'].quantity * cubes[
        'image'].pixels_per_beam * u.pix / u.beam

    return imgs, cubes


asinhn = visualization.ImageNormalize(stretch=visualization.AsinhStretch())


def show(imgs,
         zoom=None,
         clear=True,
         norm=asinhn,
         imnames_toplot=('image', 'model', 'residual', 'mask'),
         **kwargs):

    if clear:
        pl.clf()

    if 'mask' not in imgs:
        imnames_toplot = list(imnames_toplot)
        imnames_toplot.remove('mask')
Example #19
def mask_stars(setup,
               ref_image,
               ref_star_catalog,
               psf_mask,
               diagnostics=False):
    """Function to create a mask for an image, returning a 2D image that has 
    a value of 0.0 in every pixel that falls within the PSF of a star, and
    a value of 1.0 everywhere else."""

    psf_size = psf_mask.shape[0]

    half_psf = int(psf_size / 2.0)

    pxmin = 0
    pxmax = psf_mask.shape[0]
    pymin = 0
    pymax = psf_mask.shape[1]

    star_mask = np.ones(ref_image.shape)

    for j in range(0, len(ref_star_catalog), 1):

        xstar = int(ref_star_catalog[j, 1])
        ystar = int(ref_star_catalog[j, 2])

        xmin = xstar - half_psf
        xmax = xstar + half_psf + 1
        ymin = ystar - half_psf
        ymax = ystar + half_psf + 1

        px1 = pxmin
        px2 = pxmax + 1
        py1 = pymin
        py2 = pymax + 1

        if xmin < 0:

            px1 = abs(xmin)
            xmin = 0

        if ymin < 0:

            py1 = abs(ymin)
            ymin = 0

        if xmax > ref_image.shape[0]:

            px2 = px2 - (xmax - ref_image.shape[0] + 1)
            xmax = ref_image.shape[0] + 1

        if ymax > ref_image.shape[1]:

            py2 = py2 - (ymax - ref_image.shape[1] + 1)
            ymax = ref_image.shape[1] + 1

        star_mask[ymin:ymax,xmin:xmax] = star_mask[ymin:ymax,xmin:xmax] + \
                                            psf_mask[py1:py2,px1:px2]

    idx_mask = np.where(star_mask > 1.0)
    star_mask[idx_mask] = np.nan

    idx = np.isnan(star_mask)
    idx = np.where(idx == False)

    star_mask[idx] = 0.0
    star_mask[idx_mask] = 1.0

    masked_ref_image = np.ma.masked_array(ref_image, mask=star_mask)

    if diagnostics == True:

        fig = plt.figure(2)

        norm = visualization.ImageNormalize(star_mask, \
                            interval=visualization.ZScaleInterval())

        plt.imshow(star_mask, origin='lower', cmap=plt.cm.viridis, norm=norm)

        plt.plot(ref_star_catalog[:, 1], ref_star_catalog[:, 2], 'r+')

        plt.xlabel('X pixel')

        plt.ylabel('Y pixel')

        plt.colorbar()

        plt.savefig(path.join(setup.red_dir, 'ref_star_mask.png'))

        plt.close(2)

        fig = plt.figure(3)

        norm = visualization.ImageNormalize(masked_ref_image, \
                            interval=visualization.ZScaleInterval())

        plt.imshow(masked_ref_image,
                   origin='lower',
                   cmap=plt.cm.viridis,
                   norm=norm)

        plt.plot(ref_star_catalog[:, 1], ref_star_catalog[:, 2], 'r+')

        plt.xlabel('X pixel')

        plt.ylabel('Y pixel')

        plt.colorbar()

        plt.savefig(path.join(setup.red_dir, 'masked_ref_image.png'))

        plt.close(3)

    return masked_ref_image
Example #20
 def set_normalization(self,
                       stretch=None,
                       interval=None,
                       stretchkwargs={},
                       intervalkwargs={},
                       perm_linear=None):
     if stretch is None:
         if self.stretch is None:
             stretch = 'linear'
         else:
             stretch = self.stretch
     if isinstance(stretch, str):
         print(stretch,
               ' '.join([f'{k}={v}' for k, v in stretchkwargs.items()]))
         if self.data is None:  #can not calculate objects yet
             self.stretch_kwargs = stretchkwargs
         else:
             kwargs = self.prepare_kwargs(
                 self.stretch_kws_defaults[stretch], self.stretch_kwargs,
                 stretchkwargs)
             if perm_linear is not None:
                 perm_linear_kwargs = self.prepare_kwargs(
                     self.stretch_kws_defaults['linear'], perm_linear)
                 print(
                     'linear', ' '.join([
                         f'{k}={v}' for k, v in perm_linear_kwargs.items()
                     ]))
                 if stretch == 'asinh':  # arg: a=0.1
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.AsinhStretch(**kwargs))
                 elif stretch == 'contrastbias':  # args: contrast, bias
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.ContrastBiasStretch(**kwargs))
                 elif stretch == 'histogram':
                     stretch = vis.CompositeStretch(
                         vis.HistEqStretch(self.data, **kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'log':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LogStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'powerdist':  # args: a=1000.0
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.PowerDistStretch(**kwargs))
                 elif stretch == 'power':  # args: a
                     stretch = vis.CompositeStretch(
                         vis.PowerStretch(**kwargs),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'sinh':  # args: a=0.33
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SinhStretch(**kwargs))
                 elif stretch == 'sqrt':
                     stretch = vis.CompositeStretch(
                         vis.SqrtStretch(),
                         vis.LinearStretch(**perm_linear_kwargs))
                 elif stretch == 'square':
                     stretch = vis.CompositeStretch(
                         vis.LinearStretch(**perm_linear_kwargs),
                         vis.SquaredStretch())
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
             else:
                 if stretch == 'linear':  # args: slope=1, intercept=0
                     stretch = vis.LinearStretch(**kwargs)
                 else:
                     raise ValueError('Unknown stretch:' + stretch)
     self.stretch = stretch
     if interval is None:
         if self.interval is None:
             interval = 'zscale'
         else:
             interval = self.interval
     if isinstance(interval, str):
         print(interval,
               ' '.join([f'{k}={v}' for k, v in intervalkwargs.items()]))
         kwargs = self.prepare_kwargs(self.interval_kws_defaults[interval],
                                      self.interval_kwargs, intervalkwargs)
         if self.data is None:
             self.interval_kwargs = intervalkwargs
         else:
             if interval == 'minmax':
                 interval = vis.MinMaxInterval()
             elif interval == 'manual':  # args: vmin, vmax
                 interval = vis.ManualInterval(**kwargs)
             elif interval == 'percentile':  # args: percentile, n_samples
                 interval = vis.PercentileInterval(**kwargs)
             elif interval == 'asymetric':  # args: lower_percentile, upper_percentile, n_samples
                 interval = vis.AsymmetricPercentileInterval(**kwargs)
             elif interval == 'zscale':  # args: nsamples=1000, contrast=0.25, max_reject=0.5, min_npixels=5, krej=2.5, max_iterations=5
                 interval = vis.ZScaleInterval(**kwargs)
             else:
                 raise ValueError('Unknown interval:' + interval)
     self.interval = interval
     if self.img is not None:
         self.img.set_norm(
             vis.ImageNormalize(self.data,
                                interval=self.interval,
                                stretch=self.stretch,
                                clip=True))
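A condensed standalone sketch of the composite-stretch branches above (here the 'log' case), using `astropy.visualization` directly on hypothetical data:

import numpy as np
import matplotlib.pyplot as plt
import astropy.visualization as vis

data = np.random.exponential(scale=50.0, size=(200, 200))      # placeholder image

# 'log' branch above: a LogStretch chained with the permanent LinearStretch
stretch = vis.CompositeStretch(vis.LogStretch(a=1000.0),
                               vis.LinearStretch(slope=1, intercept=0))
norm = vis.ImageNormalize(data, interval=vis.ZScaleInterval(),
                          stretch=stretch, clip=True)

plt.imshow(data, origin='lower', norm=norm, cmap='viridis')
plt.colorbar()
plt.show()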
Example #21
import matplotlib.pyplot as plt
from astropy.io import fits
import astropy.visualization as viz

image_name = input("Please enter the name of the file : ")
hdul = fits.open(image_name)
hdul.info()
header_number = int(input("Enter header number whose data you want to view : "))
image = hdul[header_number].data
hdul.close()
##stretching and normalizing using LogStretch() and MinMaxInterval() like in DS9
log_param = float(input("Enter base value for logarithmic stretch : "))
norm = viz.ImageNormalize(image,
                          interval=viz.MinMaxInterval(),
                          stretch=viz.LogStretch(log_param))
plt.imshow(image, cmap='gray', norm=norm)
plt.show()
Example #22
import numpy as np
import astropy.units as u
import astropy.visualization as vis
from spectral_cube import SpectralCube
from astropy.wcs import utils as wcsutils
import pylab as pl
import pyspeckit
import paths
from astropy import modeling
from astropy import stats

cube = SpectralCube.read(
    '/Users/adam/work/w51/alma/FITS/longbaseline/velo_cutouts/w51e2e_csv0_j2-1_r0.5_medsub.fits'
)
cs21cube = subcube = cube.spectral_slab(16 * u.km / u.s, 87 * u.km / u.s)[::-1]

norm = vis.ImageNormalize(
    subcube,
    interval=vis.ManualInterval(-0.002, 0.010),
    stretch=vis.AsinhStretch(),
)

pl.rcParams['font.size'] = 12

szinch = 18
fig = pl.figure(1, figsize=(szinch, szinch))
pl.pause(0.1)
for ii in range(5):
    fig.set_size_inches(szinch, szinch)
    pl.pause(0.1)
    try:
        assert np.all(fig.get_size_inches() == np.array([szinch, szinch]))
        break
    except AssertionError:
Example #23
from astropy.io import fits
import astropy.visualization as vis
import matplotlib.pyplot as plt
from scipy import misc as msc   # face() provides the sample RGB test image used below

img0 = msc.face()[:, :, 0]  # rgb image, take one channel
hl = fits.open(
    '/Users/luke/Dropbox/proj/timmy/data/phot/2020-04-01/TIC460205581-01-0196_Rc1_out.fit'
)
img1 = hl[0].data
hl.close()

for ix, img in enumerate([img0, img1]):

    if ix == 1:
        vmin, vmax = 10, int(1e4)
        norm = vis.ImageNormalize(vmin=vmin,
                                  vmax=vmax,
                                  stretch=vis.LogStretch(1000))
    else:
        norm = None

    f, axs = plt.subplots(nrows=2, ncols=2)
    # note: this image really should have origin='upper' (otherwise trashpanda is upside-down)
    # but this is to match fits image processing convention
    axs[0, 0].imshow(img, cmap=plt.cm.gray, origin='lower', norm=norm)
    axs[0, 0].set_title('shape: {}'.format(img.shape))

    dx, dy = 200, 50
    axs[1, 0].imshow(ti.integer_shift_img(img, dx, dy),
                     cmap=plt.cm.gray,
                     origin='lower',
                     norm=norm)
Example #24
    neb_subtracted[z, :, :] = neb_subtracted[z, :, :] - neb_spect[z]

if not os.path.exists(os.path.join(data_path, 'HH305E_nebsub.fits')):
    hdr = hdul[0].header
    now = dt.utcnow().strftime('%Y/%m/%d %H:%M:%S UT')
    hdr.set('HISTORY', f'Background subtracted {now}')
    hdu = fits.PrimaryHDU(data=neb_subtracted, header=hdr)
    hdu.writeto(os.path.join(data_path, 'HH305E_nebsub.fits'))

##-------------------------------------------------------------------------
## Plot mask of low H-beta emission
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.title('Sum of H-beta Bins')
norm = v.ImageNormalize(image,
                        interval=v.ManualInterval(vmin=image.min() - 5,
                                                  vmax=image.max() + 10),
                        stretch=v.LogStretch(10))
im = plt.imshow(image, origin='lower', norm=norm, cmap='Greys')
plt.colorbar(im)

plt.subplot(1, 2, 2)
plt.title('Nebular Emission Mask')
mimage = np.ma.MaskedArray(image)
mimage.mask = ~nmask
mimagef = np.ma.filled(mimage, fill_value=0)
norm = v.ImageNormalize(mimagef,
                        interval=v.ManualInterval(
                            vmin=image.min() - 5,
                            vmax=np.percentile(image, mask_pcnt) + 5),
                        stretch=v.LinearStretch())
im = plt.imshow(mimagef, origin='lower', norm=norm, cmap='Greys')
Example #25
def plot_image_and_lines(cube,
                         wavs,
                         xrange,
                         yrange,
                         Hbeta_ref=None,
                         title='',
                         filename=None,
                         include_OIII=False):

    zpix = np.arange(0, cube.shape[0])
    lambda_delta = 5
    hbeta_z = np.where((np.array(wavs) > h_beta_std.value-lambda_delta)\
                       & (np.array(wavs) < h_beta_std.value+lambda_delta))[0]
    image = np.mean(cube[min(hbeta_z):max(hbeta_z) + 1, :, :], axis=0)

    spect = [
        np.mean(cube[z, yrange[0]:yrange[1] + 1, xrange[0]:xrange[1] + 1])
        for z in zpix
    ]
    i_peak = spect.index(max(spect))

    background_0 = models.Polynomial1D(degree=2)
    H_beta_0 = models.Gaussian1D(amplitude=500,
                                 mean=4861,
                                 stddev=1.,
                                 bounds={
                                     'mean': (4855, 4865),
                                     'stddev': (0.1, 5)
                                 })
    OIII4959_0 = models.Gaussian1D(amplitude=100,
                                   mean=4959,
                                   stddev=1.,
                                   bounds={
                                       'mean': (4955, 4965),
                                       'stddev': (0.1, 5)
                                   })
    OIII5007_0 = models.Gaussian1D(amplitude=200,
                                   mean=5007,
                                   stddev=1.,
                                   bounds={
                                       'mean': (5002, 5012),
                                       'stddev': (0.1, 5)
                                   })
    fitter = fitting.LevMarLSQFitter()
    if include_OIII is True:
        model0 = background_0 + H_beta_0 + OIII4959_0 + OIII5007_0
    else:
        model0 = background_0 + H_beta_0

    model0.mean_1 = wavs[i_peak]
    model = fitter(model0, wavs, spect)
    residuals = np.array(spect - model(wavs))

    plt.figure(figsize=(20, 8))

    plt.subplot(1, 4, 1)
    plt.title(title)
    norm = v.ImageNormalize(image,
                            interval=v.MinMaxInterval(),
                            stretch=v.LogStretch(1))
    plt.imshow(image, origin='lower', norm=norm)
    region_x = [
        xrange[0] - 0.5, xrange[1] + 0.5, xrange[1] + 0.5, xrange[0] - 0.5,
        xrange[0] - 0.5
    ]
    region_y = [
        yrange[0] - 0.5, yrange[0] - 0.5, yrange[1] + 0.5, yrange[1] + 0.5,
        yrange[0] - 0.5
    ]
    plt.plot(region_x, region_y, 'r-', alpha=0.5, lw=2)

    plt.subplot(1, 4, 2)
    if Hbeta_ref is not None:
        Hbeta_velocity = (model.mean_1.value * u.Angstrom).to(
            u.km / u.s, equivalencies=u.doppler_optical(Hbeta_ref))
        title = f'H-beta ({model.mean_1.value:.1f} A, v={Hbeta_velocity.value:.1f} km/s)'
    else:
        title = f'H-beta ({model.mean_1.value:.1f} A, sigma={model.stddev_1.value:.3f} A)'
    plt.title(title)
    w = [l for l in np.arange(4856, 4866, 0.05)]
    if Hbeta_ref is not None:
        vs = [(l * u.Angstrom).to(
            u.km / u.s, equivalencies=u.doppler_optical(Hbeta_ref)).value
              for l in wavs]
        plt.plot(vs, spect, drawstyle='steps-mid', label='data')
        vs = [(l * u.Angstrom).to(
            u.km / u.s, equivalencies=u.doppler_optical(Hbeta_ref)).value
              for l in w]
        plt.plot(vs, model(w), 'r-', alpha=0.7, label='Fit')
        plt.xlabel('Velocity (km/s)')
        plt.xlim(-200, 200)
    else:
        plt.plot(wavs, spect, drawstyle='steps-mid', label='data')
        plt.plot(w, model(w), 'r-', alpha=0.7, label='Fit')
        plt.xlabel('Wavelength (angstroms)')
        plt.xlim(4856, 4866)
    plt.grid()
    plt.ylabel('Flux')
    plt.legend(loc='best')

    plt.subplot(1, 4, 3)
    if include_OIII is True:
        title = f'OIII 4959 ({model.mean_2.value:.1f} A, sigma={model.stddev_2.value:.3f} A)'
    else:
        title = f'OIII 4959'
    plt.title(title)
    plt.plot(wavs, spect, drawstyle='steps-mid', label='data')
    w = [l for l in np.arange(4954, 4964, 0.05)]
    plt.plot(w, model(w), 'r-', alpha=0.7, label='Fit')
    plt.xlabel('Wavelength (angstroms)')
    plt.ylabel('Flux')
    plt.legend(loc='best')
    plt.xlim(4954, 4964)

    plt.subplot(1, 4, 4)
    if include_OIII is True:
        title = f'OIII 5007 ({model.mean_3.value:.1f} A, sigma={model.stddev_3.value:.3f} A)'
    else:
        title = f'OIII 5007'
    plt.title(title)
    plt.plot(wavs, spect, drawstyle='steps-mid', label='data')
    w = [l for l in np.arange(5002, 5012, 0.05)]
    plt.plot(w, model(w), 'r-', alpha=0.7, label='Fit')
    plt.xlabel('Wavelength (angstroms)')
    plt.ylabel('Flux')
    plt.legend(loc='best')
    plt.xlim(5002, 5012)

    if filename is not None:
        plt.savefig(filename, bbox_inches='tight', pad_inches=0.10)
    else:
        plt.show()

    return spect, model
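The wavelength-to-velocity conversion used above, isolated as a short sketch; the rest wavelength and fitted line centre are illustrative values, and only `astropy.units` is assumed:

from astropy import units as u

Hbeta_ref = 4861.35 * u.Angstrom     # approximate H-beta rest wavelength
observed = 4862.0 * u.Angstrom       # a hypothetical fitted line centre

velocity = observed.to(u.km / u.s, equivalencies=u.doppler_optical(Hbeta_ref))
print(f"v = {velocity:.1f}")         # roughly +40 km/s for this offset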
Example #26
File: utils.py Project: ntejos/pyntejos
def plot_fits(img,
              header,
              figsize=(10, 10),
              fontsize=16,
              levels=(None, None),
              lognorm=False,
              title=None,
              show=True,
              cmap="viridis"):
    """
    Show a fits image. (c) Sunil Sumha's code
    Parameters
    ----------
    img: np.ndarray
        Image data
    header: fits.header.Header
        Fits image header
    figsize: tuple of ints, optional
        Size of figure to be displayed (x,y)
    levels: tuple of floats, optional
        Minimum and maximum pixel values
        for visualisation.
    lognorm: bool, optional
        If true, the visualisation is log
        stretched.
    title: str, optional
        Title of the image
    show: bool, optional
        If true, displays the image.
        Else, returns the fig, ax
    cmap: str or pyplot cmap, optional
        Defaults to viridis

    Returns
    -------
    None if show is False. fig, ax if True
    """
    from astropy.wcs import WCS
    from astropy.stats import sigma_clipped_stats
    from astropy import visualization as vis

    plt.rcParams['font.size'] = fontsize
    wcs = WCS(header)

    _, median, sigma = sigma_clipped_stats(img)

    assert len(levels) == 2, "Invalid levels. Use this format: (vmin,vmax)"
    vmin, vmax = levels

    if vmin is None:
        vmin = median
    if vmax is not None:
        if vmin > vmax:
            vmin = vmax - 10 * sigma
            warnings.warn(
                "levels changed to ({:f},{:f}) because input vmin waz greater than vmax"
                .format(vmin, vmax))
    else:
        vmax = median + 10 * sigma

    fig = plt.figure(figsize=figsize)
    ax = plt.subplot(projection=wcs)

    if lognorm:
        ax.imshow(img,
                  vmax=vmax,
                  vmin=vmin,
                  norm=vis.ImageNormalize(stretch=vis.LogStretch()),
                  cmap=cmap)
    else:
        ax.imshow(img, vmax=vmax, vmin=vmin, cmap=cmap)
    ax.set_xlabel("RA")
    ax.set_ylabel("Dec")
    ax.set_title(title)
    if show:
        plt.show()
    else:
        return fig, ax
Example #27
    def makeSlitIllum(self, adinputs=None, **params):
        """
        Makes the processed Slit Illumination Function by binning a 2D
        spectrum along the dispersion direction, fitting a smooth function
        for each bin, fitting a smooth 2D model, and reconstructing the 2D
        array using this last model.

        Its implementation is based on IRAF's `noao.twodspec.longslit.illumination`
        task, following the algorithm described in [Valdes, 1986].

        It expects the input calibration image to be a dispersed image of the
        slit without illumination problems (e.g., a twilight flat). The spectrum
        is not required to be smooth in wavelength and may contain strong
        emission and absorption lines. The image should contain a `.mask`
        attribute in each extension, and it is expected to be overscan and bias
        corrected.

        Parameters
        ----------
        adinputs : list
            List of AstroData objects containing the dispersed image of the
            slit of a source free of illumination problems. The data needs to
            have been overscan and bias corrected and is expected to have a
            Data Quality mask.
        bins : {None, int}, optional
            Total number of bins across the dispersion axis. If None,
            the number of bins will match the number of extensions on each
            input AstroData object. If it is an int, it will create N bins
            with the same size.
        border : int, optional
            Border size that is added on every edge of the slit illumination
            image before cutting it down to the input AstroData frame.
        smooth_order : int, optional
            Order of the spline that is used within each bin to smooth
            the data (Default: 3).
        x_order : int, optional
            Order of the x-component in the Chebyshev2D model used to
            reconstruct the 2D data from the binned data.
        y_order : int, optional
            Order of the y-component in the Chebyshev2D model used to
            reconstruct the 2D data from the binned data.

        Returns
        -------
        List of AstroData : containing an AstroData with the Slit Illumination
            Response Function for each of the input objects.

        References
        ----------
        .. [Valdes, 1986] Francisco Valdes "Reduction Of Long Slit Spectra With
           IRAF", Proc. SPIE 0627, Instrumentation in Astronomy VI,
           (13 October 1986); https://doi.org/10.1117/12.968155
        """
        log = self.log
        log.debug(gt.log_message("primitive", self.myself(), "starting"))
        timestamp_key = self.timestamp_keys[self.myself()]

        suffix = params["suffix"]
        bins = params["bins"]
        border = params["border"]
        debug_plot = params["debug_plot"]
        smooth_order = params["smooth_order"]
        cheb2d_x_order = params["x_order"]
        cheb2d_y_order = params["y_order"]

        ad_outputs = []
        for ad in adinputs:

            if len(ad) > 1 and "mosaic" not in ad[0].wcs.available_frames:

                log.info('Add "mosaic" gWCS frame to input data')
                geotable = import_module('.geometry_conf', self.inst_lookups)

                # deepcopy prevents modifying input `ad` inplace
                ad = transform.add_mosaic_wcs(deepcopy(ad), geotable)

                log.info("Temporarily mosaicking multi-extension file")
                mosaicked_ad = transform.resample_from_wcs(
                    ad,
                    "mosaic",
                    attributes=None,
                    order=1,
                    process_objcat=False)

            else:

                log.info('Input data has a single extension or already '
                         'contains a "mosaic" frame.')

                # deepcopy prevents modifying input `ad` inplace
                mosaicked_ad = deepcopy(ad)

            log.info("Transposing data if needed")
            dispaxis = 2 - mosaicked_ad[0].dispersion_axis()  # python sense
            should_transpose = dispaxis == 1

            data, mask, variance = _transpose_if_needed(
                mosaicked_ad[0].data,
                mosaicked_ad[0].mask,
                mosaicked_ad[0].variance,
                transpose=should_transpose)

            log.info("Masking data")
            data = np.ma.masked_array(data, mask=mask)
            variance = np.ma.masked_array(variance, mask=mask)
            std = np.sqrt(variance)  # Easier to work with

            log.info("Creating bins for data and variance")
            height = data.shape[0]
            width = data.shape[1]

            if bins is None:
                nbins = max(len(ad), 12)
                bin_limits = np.linspace(0, height, nbins + 1, dtype=int)
            elif isinstance(bins, int):
                nbins = bins
                bin_limits = np.linspace(0, height, nbins + 1, dtype=int)
            else:
                # ToDo: Handle input bins as array
                raise TypeError("Expected None or Int for `bins`. "
                                "Found: {}".format(type(bins)))

            bin_top = bin_limits[1:]
            bin_bot = bin_limits[:-1]
            binned_data = np.zeros_like(data)
            binned_std = np.zeros_like(std)

            log.info("Smooth binned data and variance, and normalize them by "
                     "smoothed central value")
            for bin_idx, (b0, b1) in enumerate(zip(bin_bot, bin_top)):

                rows = np.arange(width)

                avg_data = np.ma.mean(data[b0:b1], axis=0)
                model_1d_data = astromodels.UnivariateSplineWithOutlierRemoval(
                    rows, avg_data, order=smooth_order)

                avg_std = np.ma.mean(std[b0:b1], axis=0)
                model_1d_std = astromodels.UnivariateSplineWithOutlierRemoval(
                    rows, avg_std, order=smooth_order)

                slit_central_value = model_1d_data(rows)[width // 2]
                binned_data[b0:b1] = model_1d_data(rows) / slit_central_value
                binned_std[b0:b1] = model_1d_std(rows) / slit_central_value

            log.info("Reconstruct 2D mosaicked data")
            bin_center = np.array(0.5 * (bin_bot + bin_top), dtype=int)
            cols_fit, rows_fit = np.meshgrid(np.arange(width), bin_center)

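            # Fit a 2D Chebyshev surface to the normalized bin profiles,
            # sampled at the bin centres, so that the illumination function
            # can be evaluated at every pixel (including the padded border
            # region further below).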
            fitter = fitting.SLSQPLSQFitter()
            model_2d_init = models.Chebyshev2D(x_degree=cheb2d_x_order,
                                               x_domain=(0, width),
                                               y_degree=cheb2d_y_order,
                                               y_domain=(0, height))

            model_2d_data = fitter(model_2d_init, cols_fit, rows_fit,
                                   binned_data[rows_fit, cols_fit])

            model_2d_std = fitter(model_2d_init, cols_fit, rows_fit,
                                  binned_std[rows_fit, cols_fit])

            rows_val, cols_val = \
                np.mgrid[-border:height+border, -border:width+border]

            slit_response_data = model_2d_data(cols_val, rows_val)
            slit_response_mask = np.pad(
                mask, border, mode='edge')  # ToDo: any update to the mask?
            slit_response_std = model_2d_std(cols_val, rows_val)
            slit_response_var = slit_response_std**2

            del cols_fit, cols_val, rows_fit, rows_val

            _data, _mask, _variance = _transpose_if_needed(
                slit_response_data,
                slit_response_mask,
                slit_response_var,
                transpose=dispaxis == 1)

            log.info("Update slit response data and data_section")
            slit_response_ad = deepcopy(mosaicked_ad)
            slit_response_ad[0].data = _data
            slit_response_ad[0].mask = _mask
            slit_response_ad[0].variance = _variance

            if "mosaic" in ad[0].wcs.available_frames:

                log.info(
                    "Map coordinates between slit function and mosaicked data"
                )  # ToDo: Improve message?
                slit_response_ad = _split_mosaic_into_extensions(
                    ad, slit_response_ad, border_size=border)

            elif len(ad) == 1:

                log.info("Trim out borders")

                slit_response_ad[0].data = \
                    slit_response_ad[0].data[border:-border, border:-border]
                slit_response_ad[0].mask = \
                    slit_response_ad[0].mask[border:-border, border:-border]
                slit_response_ad[0].variance = \
                    slit_response_ad[0].variance[border:-border, border:-border]

            log.info("Update metadata and filename")
            gt.mark_history(slit_response_ad,
                            primname=self.myself(),
                            keyword=timestamp_key)

            slit_response_ad.update_filename(suffix=suffix, strip=True)
            ad_outputs.append(slit_response_ad)

            # Plotting ------
            if debug_plot:

                log.info("Creating plots")
                palette = copy(plt.cm.cividis)
                palette.set_bad('r', 0.75)

                norm = vis.ImageNormalize(data[~data.mask],
                                          stretch=vis.LinearStretch(),
                                          interval=vis.PercentileInterval(97))

                fig = plt.figure(num="Slit Response from MEF - {}".format(
                    ad.filename),
                                 figsize=(12, 9),
                                 dpi=110)

                gs = gridspec.GridSpec(nrows=2, ncols=3, figure=fig)

                # Display raw mosaicked data and its bins ---
                ax1 = fig.add_subplot(gs[0, 0])
                im1 = ax1.imshow(data,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=norm.vmin,
                                 vmax=norm.vmax)

                ax1.set_title("Mosaicked Data\n and Spectral Bins",
                              fontsize=10)
                ax1.set_xlim(-1, data.shape[1])
                ax1.set_xticks([])
                ax1.set_ylim(-1, data.shape[0])
                ax1.set_yticks(bin_center)
                ax1.tick_params(axis=u'both', which=u'both', length=0)

                ax1.set_yticklabels(
                    ["Bin {}".format(i) for i in range(len(bin_center))],
                    fontsize=6)

                _ = [ax1.spines[s].set_visible(False) for s in ax1.spines]
                _ = [ax1.axhline(b, c='w', lw=0.5) for b in bin_limits]

                divider = make_axes_locatable(ax1)
                cax1 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im1, cax=cax1)

                # Display non-smoothed bins ---
                ax2 = fig.add_subplot(gs[0, 1])
                im2 = ax2.imshow(binned_data, cmap=palette, origin='lower')

                ax2.set_title("Binned, smoothed\n and normalized data ",
                              fontsize=10)
                ax2.set_xlim(0, data.shape[1])
                ax2.set_xticks([])
                ax2.set_ylim(0, data.shape[0])
                ax2.set_yticks(bin_center)
                ax2.tick_params(axis=u'both', which=u'both', length=0)

                ax2.set_yticklabels(
                    ["Bin {}".format(i) for i in range(len(bin_center))],
                    fontsize=6)

                _ = [ax2.spines[s].set_visible(False) for s in ax2.spines]
                _ = [ax2.axhline(b, c='w', lw=0.5) for b in bin_limits]

                divider = make_axes_locatable(ax2)
                cax2 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im2, cax=cax2)

                # Display reconstructed slit response ---
                vmin = slit_response_data.min()
                vmax = slit_response_data.max()

                ax3 = fig.add_subplot(gs[1, 0])
                im3 = ax3.imshow(slit_response_data,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=vmin,
                                 vmax=vmax)

                ax3.set_title("Reconstructed\n Slit response", fontsize=10)
                ax3.set_xlim(0, data.shape[1])
                ax3.set_xticks([])
                ax3.set_ylim(0, data.shape[0])
                ax3.set_yticks([])
                ax3.tick_params(axis=u'both', which=u'both', length=0)
                _ = [ax3.spines[s].set_visible(False) for s in ax3.spines]

                divider = make_axes_locatable(ax3)
                cax3 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im3, cax=cax3)

                # Display extensions ---
                ax4 = fig.add_subplot(gs[1, 1])
                ax4.set_xticks([])
                ax4.set_yticks([])
                _ = [ax4.spines[s].set_visible(False) for s in ax4.spines]

                sub_gs4 = gridspec.GridSpecFromSubplotSpec(nrows=len(ad),
                                                           ncols=1,
                                                           subplot_spec=gs[1,
                                                                           1],
                                                           hspace=0.03)

                # The [::-1] is needed to put the first extension at the bottom
                for i, ext in enumerate(slit_response_ad[::-1]):

                    ext_data, ext_mask, ext_variance = _transpose_if_needed(
                        ext.data,
                        ext.mask,
                        ext.variance,
                        transpose=dispaxis == 1)

                    ext_data = np.ma.masked_array(ext_data, mask=ext_mask)

                    sub_ax = fig.add_subplot(sub_gs4[i])

                    im4 = sub_ax.imshow(ext_data,
                                        origin="lower",
                                        vmin=vmin,
                                        vmax=vmax,
                                        cmap=palette)

                    sub_ax.set_xlim(0, ext_data.shape[1])
                    sub_ax.set_xticks([])
                    sub_ax.set_ylim(0, ext_data.shape[0])
                    sub_ax.set_yticks([ext_data.shape[0] // 2])

                    sub_ax.set_yticklabels(
                        ["Ext {}".format(len(slit_response_ad) - i - 1)],
                        fontsize=6)

                    _ = [
                        sub_ax.spines[s].set_visible(False)
                        for s in sub_ax.spines
                    ]

                    if i == 0:
                        sub_ax.set_title(
                            "Multi-extension\n Slit Response Function")

                divider = make_axes_locatable(ax4)
                cax4 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im4, cax=cax4)

                # Display Signal-To-Noise Ratio ---
                snr = data / np.sqrt(variance)

                norm = vis.ImageNormalize(snr[~snr.mask],
                                          stretch=vis.LinearStretch(),
                                          interval=vis.PercentileInterval(97))

                ax5 = fig.add_subplot(gs[0, 2])

                im5 = ax5.imshow(snr,
                                 cmap=palette,
                                 origin='lower',
                                 vmin=norm.vmin,
                                 vmax=norm.vmax)

                ax5.set_title("Mosaicked Data SNR", fontsize=10)
                ax5.set_xlim(-1, data.shape[1])
                ax5.set_xticks([])
                ax5.set_ylim(-1, data.shape[0])
                ax5.set_yticks(bin_center)
                ax5.tick_params(axis=u'both', which=u'both', length=0)

                ax5.set_yticklabels(
                    ["Bin {}".format(i) for i in range(len(bin_center))],
                    fontsize=6)

                _ = [ax5.spines[s].set_visible(False) for s in ax5.spines]
                _ = [ax5.axhline(b, c='w', lw=0.5) for b in bin_limits]

                divider = make_axes_locatable(ax5)
                cax5 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im5, cax=cax5)

                # Display Signal-To-Noise Ratio of Slit Illumination ---
                slit_response_snr = np.ma.masked_array(
                    slit_response_data / np.sqrt(slit_response_var),
                    mask=slit_response_mask)

                ax6 = fig.add_subplot(gs[1, 2])

                im6 = ax6.imshow(slit_response_snr,
                                 origin="lower",
                                 vmin=norm.vmin,
                                 vmax=norm.vmax,
                                 cmap=palette)

                ax6.set_xlim(0, slit_response_snr.shape[1])
                ax6.set_xticks([])
                ax6.set_ylim(0, slit_response_snr.shape[0])
                ax6.set_yticks([])
                ax6.set_title("Reconstructed\n Slit Response SNR")

                _ = [ax6.spines[s].set_visible(False) for s in ax6.spines]

                divider = make_axes_locatable(ax6)
                cax6 = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im6, cax=cax6)

                # Save plots ---
                fig.tight_layout(rect=[0, 0, 0.95, 1], pad=0.5)
                fname = slit_response_ad.filename.replace(".fits", ".png")
                log.info("Saving plots to {}".format(fname))
                plt.savefig(fname)

        return ad_outputs
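
As a rough, standalone illustration of the bin-fit-reconstruct idea described in the docstring above, here is a minimal sketch. It is not the DRAGONS implementation: it ignores masks, variance, borders and transposing, uses a plain SciPy spline instead of `astromodels.UnivariateSplineWithOutlierRemoval`, and a `LinearLSQFitter` instead of `SLSQPLSQFitter`; the function name and parameter defaults are illustrative assumptions.

import numpy as np
from astropy.modeling import fitting, models
from scipy.interpolate import UnivariateSpline


def slit_illumination_sketch(frame, n_bins=12, spline_order=3,
                             x_order=4, y_order=4):
    """Approximate a smooth 2D slit-illumination function for `frame`
    (rows run along the dispersion direction, columns along the slit)."""
    height, width = frame.shape
    cols = np.arange(width)
    limits = np.linspace(0, height, n_bins + 1, dtype=int)
    centers = (0.5 * (limits[:-1] + limits[1:])).astype(int)

    # Collapse each bin along the dispersion axis, smooth the average slit
    # profile with a spline and normalize it by its central value.
    binned = np.zeros_like(frame, dtype=float)
    for b0, b1 in zip(limits[:-1], limits[1:]):
        profile = UnivariateSpline(cols, frame[b0:b1].mean(axis=0),
                                   k=spline_order)(cols)
        binned[b0:b1] = profile / profile[width // 2]

    # Fit a Chebyshev2D surface through the bin centres and evaluate it over
    # the whole frame to reconstruct the 2D illumination function.
    yy, xx = np.meshgrid(centers, cols, indexing='ij')
    fit = fitting.LinearLSQFitter()
    surface = fit(models.Chebyshev2D(x_degree=x_order, y_degree=y_order,
                                     x_domain=(0, width),
                                     y_domain=(0, height)),
                  xx, yy, binned[yy, xx])
    cols_full, rows_full = np.meshgrid(np.arange(width), np.arange(height))
    return surface(cols_full, rows_full)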
Example #28
def plot_image_fit_residuals(fig, image, fit, residuals=None, percentile=95.0):
    """
	Make a figure with three subplots showing the image, the fit and the
	residuals. The image and the fit are shown with logarithmic scaling and a
	common colorbar. The residuals are shown with linear scaling and a separate
	colorbar.

	Parameters:
		fig (fig object): Figure object in which to make the subplots.
		image (2D array): Image numpy array.
		fit (2D array): Fitted image numpy array.
		residuals (2D array, optional): Residual image (``image - fit``); computed if not given.
		percentile (float, optional): Percentile used for the common color scaling of the image and fit panels. Default=95.

	Returns:
		list: List with Matplotlib subplot axes objects for each subplot.
	"""

    if residuals is None:
        residuals = image - fit

    # Calculate common normalization for the first two subplots:
    vmin_image, vmax_image = viz.PercentileInterval(percentile).get_limits(
        image)
    vmin_fit, vmax_fit = viz.PercentileInterval(percentile).get_limits(fit)
    vmin = np.nanmin([vmin_image, vmin_fit])
    vmax = np.nanmax([vmax_image, vmax_fit])
    norm = viz.ImageNormalize(vmin=vmin, vmax=vmax, stretch=viz.LogStretch())

    # Add subplot with the image:
    ax1 = fig.add_subplot(131)
    im1 = plot_image(image, ax=ax1, scale=norm, cbar=None, title='Image')

    # Add subplot with the fit:
    ax2 = fig.add_subplot(132)
    plot_image(fit, ax=ax2, scale=norm, cbar=None, title='PSF fit')

    # Calculate the normalization for the third subplot:
    vmin, vmax = viz.PercentileInterval(percentile).get_limits(residuals)
    v = np.max(np.abs([vmin, vmax]))

    # Add subplot with the residuals:
    ax3 = fig.add_subplot(133)
    im3 = plot_image(residuals,
                     ax=ax3,
                     scale='linear',
                     cmap='seismic',
                     vmin=-v,
                     vmax=v,
                     cbar=None,
                     title='Residuals')

    # Make the common colorbar for image and fit subplots:
    cbar_ax12 = fig.add_axes([0.125, 0.2, 0.494, 0.03])
    fig.colorbar(im1, cax=cbar_ax12, orientation='horizontal')

    # Make the colorbar for the residuals subplot:
    cbar_ax3 = fig.add_axes([0.7, 0.2, 0.205, 0.03])
    fig.colorbar(im3, cax=cbar_ax3, orientation='horizontal')

    # Add more space between subplots:
    plt.subplots_adjust(wspace=0.4, hspace=0.4)

    return [ax1, ax2, ax3]
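
A hedged usage sketch for the helper above; the synthetic Gaussian image/fit pair and the figure size are illustrative only, and `plot_image_fit_residuals` (plus the `plot_image` helper it calls) is assumed to be importable from the same module:

import numpy as np
import matplotlib.pyplot as plt

# Build a noisy synthetic PSF and a noiseless "fit" to compare against.
yy, xx = np.mgrid[:64, :64]
fit = np.exp(-((xx - 32.0)**2 + (yy - 32.0)**2) / (2 * 4.0**2))
image = fit + np.random.normal(scale=0.01, size=fit.shape)

fig = plt.figure(figsize=(14, 5))
axes = plot_image_fit_residuals(fig, image, fit)  # residuals default to image - fit
fig.suptitle('PSF fit diagnostics')
plt.show()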
Example #29
def make_analysis_forms(
        basepath="/orange/adamginsburg/web/secure/ALMA-IMF/October31Release/",
        base_form_url="https://docs.google.com/forms/d/e/1FAIpQLSczsBdB3Am4znOio2Ky5GZqAnRYDrYTD704gspNu7fAMm2-NQ/viewform?embedded=true",
        dontskip_noresid=False):
    import glob
    from diagnostic_images import load_images, show as show_images
    from astropy import visualization
    import pylab as pl

    savepath = f'{basepath}/quicklooks'

    try:
        os.mkdir(savepath)
    except FileExistsError:
        # The quicklook directory already exists; nothing to do.
        pass

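    # Map (field, band, config, robust, selfcal) -> list of matching
    # .image.tt0 FITS files, globbing over every combination of the
    # hard-coded fields, bands and imaging parameters below.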
    filedict = {
        (field, band, config, robust, selfcal): glob.glob(
            f"{field}/B{band}/{imtype}{field}*_B{band}_*_{config}_robust{robust}*selfcal{selfcal}*.image.tt0*.fits"
        )
        for field in
        "G008.67 G337.92 W43-MM3 G328.25 G351.77 G012.80 G327.29 W43-MM1 G010.62 W51-IRS2 W43-MM2 G333.60 G338.93 W51-E G353.41"
        .split() for band in (3, 6)
        #for config in ('7M12M', '12M')
        for config in ('12M', )
        #for robust in (-2, 0, 2)
        for robust in (0, ) for selfcal in ("", ) + tuple(range(0, 9))
        for imtype in (('', ) if 'October31' in basepath else ('cleanest/',
                                                               'bsens/'))
    }
    badfiledict = {key: val for key, val in filedict.items() if len(val) == 1}
    print(f"Bad files: {badfiledict}")
    filedict = {key: val for key, val in filedict.items() if len(val) > 1}
    filelist = [key + (fn, ) for key, val in filedict.items() for fn in val]

    prev = 'index.html'

    flist = []

    #for field in "G008.67 G337.92 W43-MM3 G328.25 G351.77 G012.80 G327.29 W43-MM1 G010.62 W51-IRS2 W43-MM2 G333.60 G338.93 W51-E G353.41".split():
    ##for field in ("G333.60",):
    #    for band in (3,6):
    #        for config in ('7M12M', '12M'):
    #            for robust in (-2, 0, 2):

    #                # for not all-in-the-same-place stuff
    #                fns = [x for x in glob.glob(f"{field}/B{band}/{field}*_B{band}_*_{config}_robust{robust}*selfcal[0-9]*.image.tt0*.fits") ]

    #                for fn in fns:
    for ii, (field, band, config, robust, selfcal, fn) in enumerate(filelist):

        image = fn
        basename, suffix = image.split(".image.tt0")
        if 'diff' in suffix or 'bsens-cleanest' in suffix:
            continue
        outname = basename.split("/")[-1]

        if prev == outname + ".html":
            print(
                f"{ii}: {(field, band, config, robust, fn)} yielded the same prev "
                f"{prev} as last time, skipping.")
            continue

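        # Look ahead in the file list for the next distinct output page, so
        # the quicklook form can link "next" to it (falling back to
        # index.html at the end of the list).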
        jj = 1
        while jj < len(filelist):
            if ii + jj < len(filelist):
                next_ = filelist[ii + jj][5].split(".image.tt0")[0].split(
                    "/")[-1] + ".html"
            else:
                next_ = "index.html"

            if next_ == outname + ".html":
                jj = jj + 1
            else:
                break

        assert next_ != outname + ".html"

        try:
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore')
                print(f"{ii}: {(field, band, config, robust, fn, selfcal)}"
                      f" basename='{basename}', suffix='{suffix}'")
                imgs, cubes = load_images(basename, suffix=suffix)
        except KeyError as ex:
            print(ex)
            raise
        except Exception as ex:
            print(f"EXCEPTION: {type(ex)}: {str(ex)}")
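            # NOTE: the `raise` below makes the `continue` unreachable; comment
            # it out to skip problematic files instead of aborting the loop.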
            raise
            continue
        norm = visualization.ImageNormalize(
            stretch=visualization.AsinhStretch(),
            interval=visualization.PercentileInterval(99.95))
        # set the scaling based on one of these...
        # (this call modifies `norm` in place, according to the docs; a
        # standalone sketch of this pattern follows this function)
        if 'residual' in imgs:
            norm(imgs['residual'][imgs['residual'] == imgs['residual']])
            imnames_toplot = ('mask', 'model', 'image', 'residual')
        elif 'image' in imgs and dontskip_noresid:
            imnames_toplot = (
                'image',
                'mask',
            )
            norm(imgs['image'][imgs['image'] == imgs['image']])
        else:
            print(
                f"Skipped {fn} because no image OR residual was found.  imgs.keys={imgs.keys()}"
            )
            continue
        pl.close(1)
        pl.figure(1, figsize=(14, 6))
        show_images(imgs, norm=norm, imnames_toplot=imnames_toplot)

        pl.savefig(f"{savepath}/{outname}.png", dpi=150, bbox_inches='tight')

        metadata = {
            'field': field,
            'band': band,
            'selfcal': selfcal,  #get_selfcal_number(basename),
            'array': config,
            'robust': robust,
            'finaliter': 'finaliter' in fn,
        }
        make_quicklook_analysis_form(filename=outname,
                                     metadata=metadata,
                                     savepath=savepath,
                                     prev=prev,
                                     next_=next_,
                                     base_form_url=base_form_url)
        metadata['outname'] = outname
        metadata['suffix'] = suffix
        if robust == 0:
            # only keep robust=0 for simplicity
            flist.append(metadata)
        prev = outname + ".html"

    #make_rand_html(savepath)
    make_index(savepath, flist)

    return flist
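
The normalization pattern used above (building an `ImageNormalize` without data, then calling it once on the finite pixels so it computes its limits) can be tried in isolation. A minimal sketch, with a synthetic array standing in for the residual map:

import numpy as np
from astropy import visualization

data = np.random.lognormal(size=(128, 128))
data[0, 0] = np.nan  # pretend a pixel is blanked

norm = visualization.ImageNormalize(
    stretch=visualization.AsinhStretch(),
    interval=visualization.PercentileInterval(99.95))

# Calling the normalization on finite values makes it compute and store its
# vmin/vmax in place; `x == x` filters out NaNs, as in the function above.
norm(data[data == data])
print(norm.vmin, norm.vmax)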
Example #30
def plot_image(image,
               ax=None,
               scale='log',
               cmap=None,
               origin='lower',
               xlabel=None,
               ylabel=None,
               cbar=None,
               clabel='Flux ($e^{-}s^{-1}$)',
               cbar_ticks=None,
               cbar_ticklabels=None,
               cbar_pad=None,
               cbar_size='5%',
               title=None,
               percentile=95.0,
               vmin=None,
               vmax=None,
               offset_axes=None,
               color_bad='k',
               **kwargs):
    """
	Utility function to plot a 2D image.

	Parameters:
		image (2d array): Image data.
		ax (matplotlib.pyplot.axes, optional): Axes in which to plot.
			Default (None) is to use current active axes.
		scale (str or :py:class:`astropy.visualization.ImageNormalize` object, optional):
			Normalization used to stretch the colormap.
			Options: ``'linear'``, ``'sqrt'``, ``'log'``, ``'asinh'``, ``'histeq'``, ``'sinh'``
			and ``'squared'``.
			Can also be a :py:class:`astropy.visualization.ImageNormalize` object.
			Default is ``'log'``.
		origin (str, optional): The origin of the coordinate system.
		xlabel (str, optional): Label for the x-axis.
		ylabel (str, optional): Label for the y-axis.
		cbar (string, optional): Location of color bar.
			Choices are ``'right'``, ``'left'``, ``'top'`` and ``'bottom'``.
			Default is not to create a colorbar.
		clabel (str, optional): Label for the color bar.
		cbar_size (str or float, optional): Fractional size of colorbar compared to axes. Default='5%'.
		cbar_pad (float, optional): Padding between axes and colorbar.
		title (str or None, optional): Title for the plot.
		percentile (float or tuple, optional): The fraction of pixels to keep in the color-trim.
			If a single float is given, the same fraction of pixels is eliminated from both ends.
			If a tuple of two floats is given, they are used as the lower and upper percentiles.
			Default=95.
		cmap (matplotlib colormap, optional): Colormap to use. Default is the ``Blues`` colormap.
		vmin (float, optional): Lower limit to use for colormap.
		vmax (float, optional): Upper limit to use for colormap.
		color_bad (str, optional): Color to apply to bad pixels (NaN). Default is black.
		kwargs (dict, optional): Keyword arguments to be passed to :py:func:`matplotlib.pyplot.imshow`.

	Returns:
		:py:class:`matplotlib.image.AxesImage`: Image from returned
			by :py:func:`matplotlib.pyplot.imshow`.

	.. codeauthor:: Rasmus Handberg <*****@*****.**>
	"""

    logger = logging.getLogger(__name__)

    # Backward compatible settings:
    make_cbar = kwargs.pop('make_cbar', None)
    if make_cbar:
        # Emit a deprecation warning instead of raising, so the legacy keyword
        # still takes effect (requires the standard-library `warnings` module).
        warnings.warn("'make_cbar' is deprecated. Use 'cbar' instead.",
                      FutureWarning)
        if not cbar:
            cbar = make_cbar

    # Special treatment for boolean arrays:
    if isinstance(image, np.ndarray) and image.dtype == 'bool':
        if vmin is None: vmin = 0
        if vmax is None: vmax = 1
        if cbar_ticks is None: cbar_ticks = [0, 1]
        if cbar_ticklabels is None: cbar_ticklabels = ['False', 'True']

    # Calculate limits of color scaling:
    interval = None
    if vmin is None or vmax is None:
        if allnan(image):
            logger.warning("Image is all NaN")
            vmin = 0
            vmax = 1
            if cbar_ticks is None:
                cbar_ticks = []
            if cbar_ticklabels is None:
                cbar_ticklabels = []
        elif isinstance(percentile, (list, tuple, np.ndarray)):
            interval = viz.AsymmetricPercentileInterval(
                percentile[0], percentile[1])
        else:
            interval = viz.PercentileInterval(percentile)

    # Create ImageNormalize object with extracted limits:
    if scale in ('log', 'linear', 'sqrt', 'asinh', 'histeq', 'sinh',
                 'squared'):
        if scale == 'log':
            stretch = viz.LogStretch()
        elif scale == 'linear':
            stretch = viz.LinearStretch()
        elif scale == 'sqrt':
            stretch = viz.SqrtStretch()
        elif scale == 'asinh':
            stretch = viz.AsinhStretch()
        elif scale == 'histeq':
            stretch = viz.HistEqStretch(image[np.isfinite(image)])
        elif scale == 'sinh':
            stretch = viz.SinhStretch()
        elif scale == 'squared':
            stretch = viz.SquaredStretch()

        # Create ImageNormalize object. Very important to use clip=False if the image contains
        # NaNs, otherwise NaN points will not be plotted correctly.
        norm = viz.ImageNormalize(data=image[np.isfinite(image)],
                                  interval=interval,
                                  vmin=vmin,
                                  vmax=vmax,
                                  stretch=stretch,
                                  clip=not anynan(image))

    elif isinstance(scale, (viz.ImageNormalize, matplotlib.colors.Normalize)):
        norm = scale
    else:
        raise ValueError("scale {} is not available.".format(scale))

    if offset_axes:
        extent = (offset_axes[0] - 0.5, offset_axes[0] + image.shape[1] - 0.5,
                  offset_axes[1] - 0.5, offset_axes[1] + image.shape[0] - 0.5)
    else:
        extent = (-0.5, image.shape[1] - 0.5, -0.5, image.shape[0] - 0.5)

    if ax is None:
        ax = plt.gca()

    # Set up the colormap to use. If a bad color is defined,
    # add it to the colormap:
    if cmap is None:
        cmap = copy.copy(plt.get_cmap('Blues'))
    elif isinstance(cmap, str):
        cmap = copy.copy(plt.get_cmap(cmap))

    if color_bad:
        cmap.set_bad(color_bad, 1.0)

    # Plotting the image using all the settings set above:
    im = ax.imshow(image,
                   cmap=cmap,
                   norm=norm,
                   origin=origin,
                   extent=extent,
                   interpolation='nearest',
                   **kwargs)

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    ax.set_xlim([extent[0], extent[1]])
    ax.set_ylim([extent[2], extent[3]])

    if cbar:
        colorbar(im,
                 ax=ax,
                 loc=cbar,
                 size=cbar_size,
                 pad=cbar_pad,
                 label=clabel,
                 ticks=cbar_ticks,
                 ticklabels=cbar_ticklabels)

    # Settings for ticks:
    integer_locator = MaxNLocator(nbins=10, integer=True)
    ax.xaxis.set_major_locator(integer_locator)
    ax.xaxis.set_minor_locator(integer_locator)
    ax.yaxis.set_major_locator(integer_locator)
    ax.yaxis.set_minor_locator(integer_locator)
    ax.tick_params(which='both', direction='out', pad=5)
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()

    return im
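
A short usage sketch (illustrative values; assumes `plot_image` and its module-level helpers such as `colorbar` are importable):

import numpy as np
import matplotlib.pyplot as plt

img = np.random.gamma(shape=2.0, scale=50.0, size=(200, 200))
img[80:85, 90:95] = np.nan  # a few bad pixels, drawn in `color_bad`

fig, ax = plt.subplots(figsize=(6, 5))
plot_image(img, ax=ax, scale='asinh', percentile=(1.0, 99.0),
           cbar='right', title='Simulated frame',
           xlabel='Column (pixels)', ylabel='Row (pixels)')
plt.show()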