Ejemplo n.º 1
0
    def test_direct_init(self):
        """Build WCSAxes directly from a DistanceToLonLat transform.

        Instead of a WCS, the axes get a custom transform plus an explicit
        coord_meta describing the lon/lat world coordinates it produces.
        Returns the figure for image comparison.
        """
        transform = DistanceToLonLat(R=6378.273)

        # Describe the two world coordinates produced by the transform.
        meta = {
            'type': ('longitude', 'latitude'),
            'wrap': (360., None),
            'unit': (u.deg, u.deg),
            'name': ('lon', 'lat'),
        }
        figure = plt.figure(figsize=(4, 4))

        axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7],
                       transform=transform, coord_meta=meta)
        figure.add_axes(axes)
        coords = axes.coords

        for name, grid_color in (('lon', 'red'), ('lat', 'blue')):
            coords[name].grid(color=grid_color, linestyle='solid', alpha=0.3)

        for name in ('lon', 'lat'):
            coords[name].set_ticklabel(size=7, exclude_overlapping=True)

        for name in ('lon', 'lat'):
            coords[name].set_ticklabel_position('brtl')

        for name in ('lon', 'lat'):
            coords[name].set_ticks(spacing=10. * u.deg)

        axes.set_xlim(-400., 500.)
        axes.set_ylim(-300., 400.)

        return figure
Ejemplo n.º 2
0
    def test_ticks_labels(self):
        """Exercise tick, axis-label and tick-label styling without a WCS.

        Each of the two coordinates gets its own size/color/font settings and
        positions. Returns the figure for image comparison.
        """
        figure = plt.figure(figsize=(6, 6))
        axes = WCSAxes(figure, [0.1, 0.1, 0.7, 0.7], wcs=None)
        figure.add_axes(axes)
        for set_limits in (axes.set_xlim, axes.set_ylim):
            set_limits(-0.5, 2)

        first, second = axes.coords[0], axes.coords[1]

        # Tick marks
        first.set_ticks(size=10, color='blue', alpha=0.2, width=1)
        second.set_ticks(size=20, color='red', alpha=0.9, width=1)
        first.set_ticks_position('all')
        second.set_ticks_position('all')

        # Axis labels
        first.set_axislabel('X-axis', size=20)
        second.set_axislabel('Y-axis', color='green', size=25,
                             weight='regular', style='normal',
                             family='cmtt10')
        first.set_axislabel_position('t')
        second.set_axislabel_position('r')

        # Tick labels
        first.set_ticklabel(color='purple', size=15, alpha=1,
                            weight='light', style='normal',
                            family='cmss10')
        second.set_ticklabel(color='black', size=18, alpha=0.9,
                             weight='bold', family='cmr10')
        first.set_ticklabel_position('all')
        second.set_ticklabel_position('r')

        return figure
Ejemplo n.º 3
0
def time_basic_plot():
    """Benchmark: draw an otherwise empty WCSAxes set up for MSX_WCS."""
    figure = Figure()
    drawing_canvas = FigureCanvas(figure)

    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    figure.add_axes(axes)

    for set_limits in (axes.set_xlim, axes.set_ylim):
        set_limits(-0.5, 148.5)

    drawing_canvas.draw()
Ejemplo n.º 4
0
    def test_copy_frame_properties_change_wcs(self):
        """Frame linewidth and color must survive ``reset_wcs()``.

        Changing the WCS creates a new frame, so the styling applied to the
        old frame has to be transferred to the new one.
        """
        figure = plt.figure()
        axes = WCSAxes(figure, [0.1, 0.1, 0.8, 0.8])
        figure.add_axes(axes)

        expected_width, expected_color = 5, 'purple'
        axes.coords.frame.set_linewidth(expected_width)
        axes.coords.frame.set_color(expected_color)

        axes.reset_wcs()

        new_frame = axes.coords.frame
        assert new_frame.get_linewidth() == expected_width
        assert new_frame.get_color() == expected_color
Ejemplo n.º 5
0
def time_contourf_with_transform():
    """Benchmark: filled contours of DATA placed via a TWOMASS_WCS transform
    onto WCSAxes set up for MSX_WCS."""
    figure = Figure()
    drawing_canvas = FigureCanvas(figure)

    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    figure.add_axes(axes)

    # The contour data is positioned through a transform from its own WCS.
    data_transform = axes.get_transform(TWOMASS_WCS)
    axes.contourf(DATA, transform=data_transform)

    # Limits chosen so the contours sit in the middle of the result.
    axes.set_xlim(32.5, 150.5)
    axes.set_ylim(-64.5, 64.5)

    drawing_canvas.draw()
Ejemplo n.º 6
0
def time_basic_plot_with_grid_and_overlay():
    """Benchmark: MSX WCSAxes with a red native-coordinate grid plus a dotted
    purple FK5 overlay grid."""
    figure = Figure()
    drawing_canvas = FigureCanvas(figure)

    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    figure.add_axes(axes)

    axes.grid(color='red', alpha=0.5, linestyle='solid')

    pixel_range = (-0.5, 148.5)
    axes.set_xlim(*pixel_range)
    axes.set_ylim(*pixel_range)

    fk5_overlay = axes.get_coords_overlay('fk5')
    fk5_overlay.grid(color='purple', ls='dotted')

    drawing_canvas.draw()
Ejemplo n.º 7
0
def time_basic_plot_with_grid():
    """Benchmark: draw MSX WCSAxes with a solid, semi-transparent red grid."""
    figure = Figure()
    drawing_canvas = FigureCanvas(figure)

    axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=MSX_WCS)
    figure.add_axes(axes)

    axes.grid(color='red', alpha=0.5, linestyle='solid')

    pixel_range = (-0.5, 148.5)
    axes.set_xlim(*pixel_range)
    axes.set_ylim(*pixel_range)

    drawing_canvas.draw()
Ejemplo n.º 8
0
    def test_coords_overlay_auto_coord_meta(self):
        """Overlay FK5 coordinates without passing coord_meta explicitly.

        ``get_coords_overlay('fk5')`` fills in coord_meta itself. Returns the
        figure for image comparison.
        """
        figure = plt.figure(figsize=(4, 4))

        msx_wcs = WCS(self.msx_header)
        axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=msx_wcs)
        figure.add_axes(axes)

        axes.grid(color='red', alpha=0.5, linestyle='solid')

        fk5 = axes.get_coords_overlay('fk5')  # coord_meta set automatically

        fk5.grid(color='black', alpha=0.5, linestyle='solid')
        for coord_name in ('ra', 'dec'):
            fk5[coord_name].set_ticks(color='black')

        axes.set_xlim(-0.5, 148.5)
        axes.set_ylim(-0.5, 148.5)

        return figure
Ejemplo n.º 9
0
    def test_coords_overlay(self):
        """Overlay a DistanceToLonLat lon/lat grid on a linear km WCS.

        The base axes use a non-projected x/y WCS in km and hide their own tick
        labels; the overlay supplies the lon/lat grids, tick labels and ticks.
        Returns the figure for image comparison.
        """
        # Simple WCS mapping pixels to non-projected distances
        km_wcs = WCS(naxis=2)
        km_wcs.wcs.ctype = ['x', 'y']
        km_wcs.wcs.cunit = ['km', 'km']
        km_wcs.wcs.crpix = [614.5, 856.5]
        km_wcs.wcs.cdelt = [6.25, 6.25]
        km_wcs.wcs.crval = [0., 0.]

        figure = plt.figure(figsize=(4, 4))

        axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7], wcs=km_wcs)
        figure.add_axes(axes)

        lonlat_transform = DistanceToLonLat(R=6378.273)

        # Hide the native x/y tick labels; the overlay provides its own.
        for native_name in ('x', 'y'):
            axes.coords[native_name].set_ticklabel_position('')

        overlay_meta = {
            'type': ('longitude', 'latitude'),
            'wrap': (360., None),
            'unit': (u.deg, u.deg),
            'name': ('lon', 'lat'),
        }

        overlay = axes.get_coords_overlay(lonlat_transform,
                                          coord_meta=overlay_meta)

        overlay.grid(color='red')
        for name, grid_color in (('lon', 'red'), ('lat', 'blue')):
            overlay[name].grid(color=grid_color, linestyle='solid', alpha=0.3)

        for name in ('lon', 'lat'):
            overlay[name].set_ticklabel(size=7, exclude_overlapping=True)

        for name in ('lon', 'lat'):
            overlay[name].set_ticklabel_position('brtl')

        for name in ('lon', 'lat'):
            overlay[name].set_ticks(spacing=10. * u.deg)

        axes.set_xlim(-0.5, 1215.5)
        axes.set_ylim(-0.5, 1791.5)

        return figure
Ejemplo n.º 10
0
    def test_custom_frame(self):
        """Draw on a HexagonalFrame with per-spine label placement.

        Axis labels and tick labels for glon/glat are assigned to specific
        frame spines, and the image is clipped to the frame patch. Returns the
        figure for image comparison.
        """
        wcs = WCS(self.msx_header)

        figure = plt.figure(figsize=(4, 4))

        axes = WCSAxes(figure, [0.15, 0.15, 0.7, 0.7],
                       wcs=wcs, frame_class=HexagonalFrame)
        figure.add_axes(axes)

        axes.coords.grid(color='white')

        image = axes.imshow(np.ones((149, 149)), vmin=0., vmax=2.,
                            origin='lower', cmap=plt.cm.gist_heat)

        # Minimum label padding keyed by spine name: 'a' and 'd' get 1,
        # the remaining spines get 2.75.
        minpad = dict.fromkeys('ad', 1)
        minpad.update(dict.fromkeys('bcef', 2.75))

        glon, glat = axes.coords['glon'], axes.coords['glat']

        glon.set_axislabel("Longitude", minpad=minpad)
        glon.set_axislabel_position('ad')

        glat.set_axislabel("Latitude", minpad=minpad)
        glat.set_axislabel_position('bcef')

        glon.set_ticklabel_position('ad')
        glat.set_ticklabel_position('bcef')

        # Limits chosen so that no labels overlap
        axes.set_xlim(5.5, 100.5)
        axes.set_ylim(5.5, 110.5)

        # Clip the image to the frame outline
        image.set_clip_path(axes.coords.frame.patch)

        return figure
Ejemplo n.º 11
0
    def test_rcparams(self):
        """Check that WCSAxes honours custom rcParams.

        Label, frame, x-tick and grid settings are overridden inside an
        ``rc_context``; returns the figure for image comparison.
        """
        custom_params = {
            # Axis labels
            'axes.labelcolor': 'purple',
            'axes.labelsize': 14,
            'axes.labelweight': 'bold',
            # Frame / background
            'axes.linewidth': 3,
            'axes.facecolor': '0.5',
            'axes.edgecolor': 'green',
            # Ticks
            'xtick.color': 'red',
            'xtick.labelsize': 8,
            'xtick.direction': 'in',
            'xtick.minor.visible': True,
            'xtick.minor.size': 5,
            'xtick.major.size': 20,
            'xtick.major.width': 3,
            'xtick.major.pad': 10,
            # Grid
            'grid.color': 'blue',
            'grid.linestyle': ':',
            'grid.linewidth': 1,
            'grid.alpha': 0.5,
        }

        with rc_context(custom_params):
            figure = plt.figure(figsize=(6, 6))
            axes = WCSAxes(figure, [0.15, 0.1, 0.7, 0.7], wcs=None)
            figure.add_axes(axes)
            axes.set_xlim(-0.5, 2)
            axes.set_ylim(-0.5, 2)
            axes.grid()
            axes.set_xlabel('X label')
            axes.set_ylabel('Y label')
            for coord in (axes.coords[0], axes.coords[1]):
                coord.set_ticklabel(exclude_overlapping=True)
            return figure
Ejemplo n.º 12
0
import matplotlib.cm as cm
from photutils import find_peaks
from astropy import wcs
from astropy.nddata import Cutout2D
from astropy import units as u

# Load the cutout image and display it: use WCSAxes when astropy's WCS and
# WCSAxes import successfully, otherwise fall back to a plain subplot. Then
# read the region file and convert its shapes to image coordinates.
# NOTE(review): `fits`, `plt` and `pyregion` are not imported in the lines
# shown here -- presumably imported earlier in the original script; confirm.
filename = './rcutout.fits'
# NOTE(review): this HDU list is never closed; it stays open for the rest
# of the script.
image = fits.open(filename)

try:
    from astropy.wcs import WCS
    from astropy.visualization.wcsaxes import WCSAxes

    # NOTE(review): this rebinds the global `wcs`, shadowing the
    # `astropy.wcs` module imported at the top of the script.
    wcs = WCS(image[0].header)
    fig = plt.figure()
    ax = WCSAxes(fig, [.1,.1,.8,.8], wcs=wcs)
    fig.add_axes(ax)
except ImportError:
    # Fallback: ordinary matplotlib axes without world-coordinate ticks.
    ax = plt.subplot(111)

# Grayscale display of the primary HDU with a fixed linear display range.
ax.imshow(image[0].data, cmap=cm.gray, vmin=0, vmax=0.00038, origin='lower')

# Read the DS9-style region file and express its shapes in the image
# coordinates defined by the primary HDU header.
region_name = './satpeaks4.reg'
r = pyregion.open(region_name).as_imagecoord(header=image[0].header)

from pyregion.mpl_helper import properties_func_default

def fixed_color(shape, saved_attrs):
    """Property function for pyregion that draws every shape in red.

    Parameters
    ----------
    shape : pyregion shape
        The shape being rendered; passed through unchanged to
        ``properties_func_default``.
    saved_attrs : tuple
        ``(attr_list, attr_dict)`` pair of the shape's saved attributes.

    Returns
    -------
    The keyword arguments produced by ``properties_func_default`` for the
    shape, computed with its color attribute forced to ``"red"``.
    """
    attr_list, attr_dict = saved_attrs
    # Build a new dict rather than mutating the caller's attributes in place.
    attr_dict = dict(attr_dict, color="red")
    kwargs = properties_func_default(shape, (attr_list, attr_dict))
    return kwargs
Ejemplo n.º 13
0
def plot_image_rs_full(map_in,
                       xrange=(-1.4, 1.4),
                       yrange=(-1.4, 1.4),
                       log_min=0.5,
                       log_max=3.5,
                       cmap_name=None,
                       outfile=None,
                       dpi=100,
                       save_interactive=False):
    """
    Quick method to plot a map by specifying the x and y range in SOLAR coordinates.

    - Unlike plot_image_rs, here the image fills the entire frame, with no outside
      annotations like axes labels or colorbars.

    - xrange and yrange are two-element lists or tuples that specify the solar
      coords in units of the observed solar radius (Rs)
      - e.g. xrange=[-1.3, 1.3], yrange=[-1.3, 1.3]

    - log_min / log_max are base-10 exponents: the logarithmic color scale runs
      from 10**log_min to 10**log_max.

    - cmap_name (optional) is a string that specifies a sunpy or matplotlib colormap.

    - if an output file is specified, it will switch to a non-interactive backend
      and save the file (at the given dpi) without showing the plot (unless
      save_interactive=True, in which case the plot is also shown).

    The input map is never modified; all changes are made to a deep copy.

    ToDo:
      - put white annotations in the corners that describe the image (time, inst, clon, b0)
      - overplot some solar grid lines (e.g. the lat=0 line and/or various clons)
    """
    # Work on a deep copy so the caller's map (and its plot settings) is untouched.
    smap = copy.deepcopy(map_in)

    # info from the map; rsun_obs is assumed to be in arcsec (its value is
    # combined with u.arcsec below)
    rs_obs = smap.rsun_obs

    # get the coordinate positions of the x and y ranges (Rs -> arcsec)
    x0 = xrange[0] * rs_obs.value * u.arcsec
    x1 = xrange[1] * rs_obs.value * u.arcsec
    y0 = yrange[0] * rs_obs.value * u.arcsec
    y1 = yrange[1] * rs_obs.value * u.arcsec
    bot_left = SkyCoord(x0, y0, frame=smap.coordinate_frame)
    top_right = SkyCoord(x1, y1, frame=smap.coordinate_frame)

    # experiment with different styles of plotting the x and y window
    # using "limits" lets you plot outside of the image window, which can be important
    # for aligning images.
    plot_method = 'limits'

    if plot_method == 'submap':
        smap = smap.submap(bot_left, top_right)

    # setup the optional colormap
    if cmap_name is not None:
        smap.plot_settings['cmap'] = plt.get_cmap(cmap_name)

    # Set the map plot min/max
    pmin = 10.0**log_min
    pmax = 10.0**log_max
    smap.plot_settings['norm'] = colors.LogNorm(pmin, pmax)

    # Show undefined values in black rather than white. Modify a copy: the
    # colormap may be a shared, globally registered instance (e.g. from
    # plt.get_cmap), and set_bad would otherwise change it for every user.
    cmap = copy.copy(smap.plot_settings['cmap'])
    cmap.set_bad(color='black')
    smap.plot_settings['cmap'] = cmap

    # if saving a file, don't use the interactive backend
    if outfile is not None and not save_interactive:
        matplotlib.use(mpl_backend_non_interactive)

    # setup the figure
    fig = plt.figure(figsize=(9, 9))

    # Manually create the axes (vs. letting map.plot do it) for more control;
    # the axes have to be explicitly added to the figure.
    axis = WCSAxes(fig, [0.0, 0.0, 1.0, 1.0], wcs=smap.wcs)
    fig.add_axes(axis)

    # plot the image
    smap.plot(axes=axis)

    # example for adjusting the tick spacing (see astropy examples for WCSAxes)
    custom_ticks = True
    if custom_ticks:
        spacing = smap.rsun_obs
        axis.coords[0].set_ticks(spacing=spacing)
        axis.coords[1].set_ticks(spacing=spacing)

    # if plot is NOT a submap, compute the pixel positions and change the matplotlib limits
    if plot_method == 'limits':
        pp_bot_left = smap.world_to_pixel(bot_left)
        pp_top_right = smap.world_to_pixel(top_right)
        axis.set_xlim(left=pp_bot_left.x.value, right=pp_top_right.x.value)
        axis.set_ylim(bottom=pp_bot_left.y.value, top=pp_top_right.y.value)

    # save the plot (optional)
    if outfile is not None:
        print("Saving image plot to: " + outfile)
        fig.savefig(outfile, dpi=dpi)
        # revert to the default MPL backend
        if not save_interactive:
            # Close this figure specifically, not whichever one is current.
            plt.close(fig)
            matplotlib.use(mpl_backend_default)
        else:
            plt.show()
    else:
        plt.show()
Ejemplo n.º 14
0
def plot(LSM, fileName=None, labelBy=None):
    """
    Shows a simple plot of the sky model.

    The circles in the plot are scaled with flux. If the sky model is grouped
    into patches, sources are colored by patch and the patch positions are
    indicated with stars.

    Parameters
    ----------
    LSM : SkyModel object
        Input sky model
    fileName : str, optional
        If given, the plot is saved to a file instead of displayed
    labelBy : str, optional
        One of 'source' or 'patch': label points using source names ('source') or
        patch names ('patch')

    Raises
    ------
    ImportError
        If matplotlib.pyplot cannot be imported.
    ValueError
        If labelBy is given and is not 'source' or 'patch' (case-insensitive).

    Examples:
    ---------
    Plot and display to the screen::

        >>> LSM = lsmtool.load('sky.model')
        >>> plot(LSM)

    Plot and save to a PDF file::

        >>> plot(LSM, 'sky_plot.pdf')

    """
    try:
        import os
        if 'DISPLAY' not in os.environ:
            import matplotlib
            # No display available: use the non-interactive Agg backend.
            # Compare by value (!=); `is not` would test string identity.
            if matplotlib.get_backend() != 'Agg':
                matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except Exception as e:
        # Exceptions have no `.message` attribute on Python 3; format the
        # exception object itself.
        raise ImportError('PyPlot could not be imported. Plotting is not '
                          'available: {0}'.format(e))

    # Prefer astropy's WCSAxes, fall back to the standalone wcsaxes package,
    # and finally to plain matplotlib axes if neither can be imported.
    try:
        from astropy.visualization.wcsaxes import WCSAxes
        hasWCSaxes = True
    except Exception:
        try:
            from wcsaxes import WCSAxes
            hasWCSaxes = True
        except Exception:
            hasWCSaxes = False
    import numpy as np
    from ..operations_lib import radec2xy, makeWCS
    # Module-level state; presumably read by formatCoord (defined elsewhere).
    global midRA, midDec, ymin, xmin

    if len(LSM) == 0:
        log.error('Sky model is empty.')
        return

    fig = plt.figure(1, figsize=(7.66, 7))
    plt.clf()
    x, y, midRA, midDec = LSM._getXY()
    if hasWCSaxes:
        wcs = makeWCS(midRA, midDec)
        ax = WCSAxes(fig, [0.16, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        ax = plt.gca()
    if LSM.hasPatches:
        nsrc = len(LSM.getPatchNames())
    else:
        nsrc = len(LSM)
    # Maps an index in [0, nsrc] to a Set3 color (one per patch when grouped).
    sm = plt.cm.ScalarMappable(cmap=plt.cm.Set3,
                               norm=plt.Normalize(vmin=0, vmax=nsrc))
    sm._A = []

    # Set symbol sizes by flux: 50 for non-positive fluxes; positive fluxes
    # start at 50 (for the faintest) and grow with log10 flux, capped at 1000.
    fluxes = LSM.getColValues('I')
    positiveFluxes = fluxes[fluxes > 0.0]
    minflux = np.min(positiveFluxes) if len(positiveFluxes) > 0 else 0.0
    s = [min(1000.0, (1.0 + 2.0 * np.log10(flux / minflux)) * 50.0)
         if flux > 0.0 else 50.0
         for flux in fluxes]

    # Color sources by patch if grouped
    c = [0] * len(LSM)
    cp = []
    if LSM.hasPatches:
        for p, patchName in enumerate(LSM.getPatchNames()):
            patchColor = sm.to_rgba(p)
            cp.append(patchColor)
            for ind in LSM.getRowIndex(patchName):
                c[ind] = patchColor
    else:
        c = [sm.to_rgba(0)] * nsrc

    # Plot sources
    if hasWCSaxes:
        ax.set_xlim(np.min(x) - 20, np.max(x) + 20)
        ax.set_ylim(np.min(y) - 20, np.max(y) + 20)
    plt.scatter(x, y, s=s, c=c)

    if LSM.hasPatches:
        RAp, Decp = LSM.getPatchPositions(asArray=True)
        goodInd = np.where((RAp != 0.0) & (Decp != 0.0))
        if len(goodInd[0]) < len(RAp):
            log.info('Some patch positions are unset. Run setPatchPositions() '
                     'before plotting to see patch positions and patch names.')
        xp, yp = radec2xy(RAp[goodInd], Decp[goodInd], midRA, midDec)
        # Keep the per-patch colors and names aligned with the patches that
        # are actually plotted (unset positions were dropped above); passing
        # the unfiltered lists would make scatter's `c` the wrong length.
        # Assumes getPatchPositions(asArray=True) follows getPatchNames()
        # order -- TODO confirm.
        patchNames = LSM.getPatchNames()
        cpGood = [cp[i] for i in goodInd[0]]
        patchNamesGood = [patchNames[i] for i in goodInd[0]]
        plt.scatter(xp, yp, s=100, c=cpGood, marker='*')

    # Set axis labels, etc.
    if hasWCSaxes:
        RAAxis = ax.coords['ra']
        RAAxis.set_axislabel('RA', minpad=0.75)
        RAAxis.set_major_formatter('hh:mm:ss')
        DecAxis = ax.coords['dec']
        DecAxis.set_axislabel('Dec', minpad=0.75)
        DecAxis.set_major_formatter('dd:mm:ss')
        ax.coords.grid(color='black', alpha=0.5, linestyle='solid')
    else:
        plt.xlabel("RA (arb. units)")
        plt.ylabel("Dec (arb. units)")

    if labelBy is not None:
        if labelBy.lower() == 'source':
            labels, xls, yls = LSM.getColValues('name'), x, y
        elif labelBy.lower() == 'patch':
            if LSM.hasPatches:
                labels, xls, yls = patchNamesGood, xp, yp
            else:
                labels, xls, yls = LSM.getColValues('name'), x, y
        else:
            raise ValueError(
                "The labelBy parameter must be one of 'source' or "
                "'patch'.")
        for label, xl, yl in zip(labels, xls, yls):
            plt.annotate(label,
                         xy=(xl, yl),
                         xytext=(-2, 2),
                         textcoords='offset points',
                         ha='right',
                         va='bottom')

    # Define coordinate formatter to show RA and Dec under mouse pointer
    ax.format_coord = formatCoord

    if fileName is not None:
        plt.savefig(fileName)
    else:
        plt.show()
    plt.close(fig)
Ejemplo n.º 15
0
def plot_state(directions_list, trim_names=True):
    """
    Plots the facets of a run.

    Draws one polygon per direction (from its vertices file) plus a small
    circle for an extra 'field' direction, labels each facet, adds an info
    box and a status legend, then shows the figure interactively. Pick and
    key-press handlers and a 60-second refresh timer are connected first.

    Parameters
    ----------
    directions_list : list
        Direction objects to plot. NOTE: a 'field' Direction is appended to
        this list in place.
    trim_names : bool, optional
        If True, label each facet with only the part of its name after the
        last underscore.
    """
    global midRA, midDec, fig, at, selected_direction, choose_from_list
    selected_direction = None
    choose_from_list = False

    # Set up coordinate system and figure
    points, midRA, midDec = factor.directions.getxy(directions_list)
    fig = plt.figure(1, figsize=(10, 9))
    if hasWCSaxes:
        wcs = factor.directions.makeWCS(midRA, midDec)
        ax = WCSAxes(fig, [0.16, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        ax = plt.gca()

    # Place the 'field' marker at (min x, max y) of the direction positions,
    # then step it diagonally up and to the left until it is at least 10 units
    # from every position. points[0] / points[1] are the x / y coordinate
    # sequences (as the min/max here use them), so the distance check must
    # iterate over (x, y) pairs built from them; iterating `points` directly
    # would yield the two coordinate rows instead.
    field_x = min(points[0])
    field_y = max(points[1])
    while any(np.sqrt((px - field_x)**2 + (py - field_y)**2) < 10.0
              for px, py in zip(points[0], points[1])):
        field_x -= 1
        field_y += 1
    field_ra, field_dec = factor.directions.xy2radec([field_x], [field_y],
                                                     refRA=midRA,
                                                     refDec=midDec)
    field = Direction('field',
                      field_ra[0],
                      field_dec[0],
                      factor_working_dir=directions_list[0].working_dir)
    # NOTE: modifies the caller's list in place.
    directions_list.append(field)

    ax.set_title('Overview of Factor run in\n{}'.format(
        directions_list[0].working_dir))

    # Common pickable style for the facet polygons and the field circle
    patch_style = dict(edgecolor='#a9a9a9',
                       facecolor='#F2F2F2',
                       clip_box=ax.bbox,
                       picker=3.0,
                       linewidth=2)

    # Plot facets and their name labels
    for direction in directions_list:
        if direction.name != 'field':
            vertices = read_vertices(direction.vertices_file)
            xverts, yverts = factor.directions.radec2xy(vertices[0],
                                                        vertices[1],
                                                        refRA=midRA,
                                                        refDec=midDec)
            xyverts = [np.array([xp, yp]) for xp, yp in zip(xverts, yverts)]
            mpl_poly = Polygon(np.array(xyverts), **patch_style)
            # Label position: centroid of the facet polygon
            centroid = SPolygon(tuple(zip(xverts, yverts))).centroid
            xmid, ymid = centroid.x, centroid.y
        else:
            mpl_poly = Circle((field_x, field_y), radius=5.0, **patch_style)
            xmid, ymid = field_x, field_y
        # Extra attributes attached to the patch (presumably read by the
        # pick handler, on_pick)
        mpl_poly.facet_name = direction.name
        mpl_poly.completed_ops = get_completed_ops(direction)
        mpl_poly.started_ops = get_started_ops(direction)
        mpl_poly.current_op = get_current_op(direction)
        set_patch_color(mpl_poly, direction)
        ax.add_patch(mpl_poly)

        # Add facet names
        name = direction.name.split('_')[-1] if trim_names else direction.name
        marker = ax.text(xmid,
                         ymid,
                         name,
                         color='k',
                         clip_on=True,
                         clip_box=ax.bbox,
                         ha='center',
                         va='bottom')
        marker.set_zorder(1001)

    # Add info box
    at = AnchoredText("Selected direction: None",
                      prop=dict(size=12),
                      frameon=True,
                      loc=3)
    at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
    at.set_zorder(1002)
    ax.add_artist(at)

    ax.relim()
    ax.autoscale()
    ax.set_aspect('equal')

    if hasWCSaxes:
        RAAxis = ax.coords['ra']
        RAAxis.set_axislabel('RA', minpad=0.75)
        RAAxis.set_major_formatter('hh:mm:ss')
        DecAxis = ax.coords['dec']
        DecAxis.set_axislabel('Dec', minpad=0.75)
        DecAxis.set_major_formatter('dd:mm:ss')
        ax.coords.grid(color='black', alpha=0.5, linestyle='solid')
    else:
        plt.xlabel("RA (arb. units)")
        plt.ylabel("Dec (arb. units)")

    # Define coordinate formatter to show RA and Dec under mouse pointer
    ax.format_coord = formatCoord

    # Show legend
    def _legend_patch(facecolor):
        """Return a 1x1 rectangle styled like a facet, used as a legend handle."""
        return plt.Rectangle((0, 0), 1, 1,
                             edgecolor='#a9a9a9',
                             facecolor=facecolor,
                             linewidth=2)

    legend_entries = [
        ('Unprocessed', '#F2F2F2'),
        ('Processing', '#F2F5A9'),
        ('Pipeline Error', '#F5A9A9'),
        ('Selfcal Failed', '#A4A4A4'),
        ('Selfcal OK', '#A9F5A9'),
    ]
    for i in range(options['reimages']):
        # Each re-imaging pass gets a progressively darker green
        scale = (i + 2)**0.5
        legend_entries.append(('Image ' + str(i + 1),
                               (0.66 / scale, 0.96 / scale, 0.66 / scale, 1.0)))
    label_list = [label for label, _ in legend_entries]
    patch_list = [_legend_patch(color) for _, color in legend_entries]
    legend = ax.legend(patch_list, label_list, loc="upper right")
    legend.set_zorder(1002)

    # Add check for mouse clicks and key presses
    fig.canvas.mpl_connect('pick_event', on_pick)
    fig.canvas.mpl_connect('key_press_event', on_press)

    # Add timer to update the plot every 60 seconds (keep a reference so it
    # stays alive while the window is shown)
    timer = fig.canvas.new_timer(interval=60000)
    timer.add_callback(update_plot)
    timer.start()

    # Show plot
    plt.show()
    plt.close(fig)

    # Clean up any temp casacore images
    if not hasaplpy:
        if os.path.exists('/tmp/tempimage'):
            try:
                shutil.rmtree('/tmp/tempimage')
            except OSError:
                pass
Ejemplo n.º 16
0
    def test_update_clip_path_change_wcs(self, tmpdir):
        """The frame clip path must carry over when ``reset_wcs()`` replaces the frame.

        A draw is forced first so the clip path returned by WCSAxes is frozen;
        then the WCS is reset and an image is drawn. Returns the figure for
        image comparison.
        """
        figure = plt.figure()
        axes = WCSAxes(figure, [0.1, 0.1, 0.8, 0.8], aspect='equal')

        figure.add_axes(axes)

        axes.set_xlim(0., 2.)
        axes.set_ylim(0., 2.)

        # Saving forces a draw, which freezes the clip path.
        scratch_path = tmpdir.join('nothing').strpath
        figure.savefig(scratch_path)

        axes.reset_wcs()

        blank = np.zeros((12, 4))
        axes.imshow(blank)

        axes.set_xlim(-0.5, 3.5)
        axes.set_ylim(-0.5, 11.5)

        return figure
Ejemplo n.º 17
0
def plot_image(img_file,
               hdu="IMAGE",
               stretch='linear',
               vmin=None,
               vmax=None,
               facecolor='black',
               center=None,
               width=None,
               figsize=(10, 10),
               cmap=None):
    """
    Plot a FITS image created by SOXS using Matplotlib.

    Parameters
    ----------
    img_file : str
        The on-disk FITS image to plot.
    hdu : str or int, optional
        The image extension to plot. Default is "IMAGE"
    stretch : str, optional
        The stretch to apply to the colorbar scale. Options are "linear",
        "log", and "sqrt". Default: "linear"
    vmin : float, optional
        The minimum value of the colorbar. If not set, it will be the minimum
        value in the image.
    vmax : float, optional
        The maximum value of the colorbar. If not set, it will be the maximum
        value in the image.
    facecolor : str, optional
        The color of zero-valued pixels. Default: "black"
    center : array-like
        A 2-element object giving an (RA, Dec) coordinate for the center
        in degrees. If not set, the reference pixel of the image (usually
        the center) is used.
    width : float, optional
        The width of the image in degrees. If not set, the width of the
        entire image will be used.
    figsize : tuple, optional
        A 2-tuple giving the size of the image in inches, e.g. (12, 15).
        Default: (10,10)
    cmap : str, optional
        The colormap to be used. If not set, the default Matplotlib
        colormap will be used.

    Returns
    -------
    A tuple of the :class:`~matplotlib.figure.Figure` and the
    :class:`~matplotlib.axes.Axes` objects.

    Raises
    ------
    RuntimeError
        If ``stretch`` is not one of "linear", "log" or "sqrt".
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import PowerNorm, LogNorm, Normalize
    from astropy.wcs.utils import proj_plane_pixel_scales
    from astropy.visualization.wcsaxes import WCSAxes
    # Validate the stretch before opening the file.
    if stretch == "linear":
        norm = Normalize(vmin=vmin, vmax=vmax)
    elif stretch == "log":
        norm = LogNorm(vmin=vmin, vmax=vmax)
    elif stretch == "sqrt":
        norm = PowerNorm(0.5, vmin=vmin, vmax=vmax)
    else:
        raise RuntimeError(f"'{stretch}' is not a valid stretch!")
    with fits.open(img_file) as f:
        image_hdu = f[hdu]
        w = wcs.WCS(image_hdu.header)
        # Per-axis pixel size in the WCS axis units (assumed degrees, like
        # `width`).
        pix_scale = proj_plane_pixel_scales(w)
        if center is None:
            # CRPIX is 1-based (FITS convention); matplotlib pixel coordinates
            # are 0-based, matching wcs_world2pix(..., 0) below.
            center = w.wcs.crpix - 1.0
        else:
            center = w.wcs_world2pix(center[0], center[1], 0)
        # Image arrays are indexed (row, column), i.e. shape is (ny, nx).
        ny, nx = image_hdu.shape[0], image_hdu.shape[1]
        if width is None:
            # Full extent of the image along each axis.
            dx_pix, dy_pix = nx, ny
        else:
            dx_pix = width / pix_scale[0]
            dy_pix = width / pix_scale[1]
        fig = plt.figure(figsize=figsize)
        ax = WCSAxes(fig, [0.15, 0.1, 0.8, 0.8], wcs=w)
        fig.add_axes(ax)
        im = ax.imshow(image_hdu.data, norm=norm, cmap=cmap)
        # Limits span dx_pix / dy_pix pixels centred on `center`.
        ax.set_xlim(center[0] - 0.5 * dx_pix, center[0] + 0.5 * dx_pix)
        ax.set_ylim(center[1] - 0.5 * dy_pix, center[1] + 0.5 * dy_pix)
        ax.set_facecolor(facecolor)
        fig.colorbar(im, ax=ax)
    return fig, ax
Ejemplo n.º 18
0
def plot(LSM, fileName=None, labelBy=None):
    """
    Shows a simple plot of the sky model.

    The circles in the plot are scaled with flux. If the sky model is grouped
    into patches, sources are colored by patch and the patch positions are
    indicated with stars.

    Parameters
    ----------
    LSM : SkyModel object
        Input sky model
    fileName : str, optional
        If given, the plot is saved to a file instead of displayed
    labelBy : str, optional
        One of 'source' or 'patch': label points using source names ('source') or
        patch names ('patch')

    Raises
    ------
    ImportError
        If matplotlib.pyplot cannot be imported.
    ValueError
        If labelBy is given but is not one of 'source' or 'patch'.

    Examples
    --------
    Plot and display to the screen::

        >>> LSM = lsmtool.load('sky.model')
        >>> plot(LSM)

    Plot and save to a PDF file::

        >>> plot(LSM, 'sky_plot.pdf')

    """
    try:
        import os
        if 'DISPLAY' not in os.environ:
            # No display available: switch to the non-interactive Agg backend.
            # Compare by value; ``is not`` on a str literal tests identity.
            import matplotlib
            if matplotlib.get_backend() != 'Agg':
                matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except Exception as e:
        # Python 3 exceptions have no ``.message`` attribute; format the
        # exception itself so the real cause is reported instead of masked.
        raise ImportError('PyPlot could not be imported. Plotting is not '
            'available: {0}'.format(e))

    # WCSAxes is optional: try astropy's copy first, then the standalone
    # wcsaxes package, and fall back to plain Matplotlib axes otherwise.
    try:
        from astropy.visualization.wcsaxes import WCSAxes
        hasWCSaxes = True
    except Exception:
        try:
            from wcsaxes import WCSAxes
            hasWCSaxes = True
        except Exception:
            hasWCSaxes = False
    import numpy as np
    from ..operations_lib import radec2xy, makeWCS
    # midRA/midDec are published as module globals (presumably for the
    # module-level ``formatCoord`` mouse-pointer formatter -- confirm).
    global midRA, midDec, ymin, xmin

    if len(LSM) == 0:
        log.error('Sky model is empty.')
        return

    fig = plt.figure(1, figsize=(7.66, 7))
    plt.clf()
    x, y, midRA, midDec = LSM._getXY()
    if hasWCSaxes:
        wcs = makeWCS(midRA, midDec)
        ax = WCSAxes(fig, [0.16, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        ax = plt.gca()
    if LSM.hasPatches:
        nsrc = len(LSM.getPatchNames())
    else:
        nsrc = len(LSM)
    # Maps an integer index (patch or source) to a Set3 color.
    sm = plt.cm.ScalarMappable(cmap=plt.cm.Set3, norm=plt.Normalize(vmin=0,
        vmax=nsrc))
    sm._A = []

    # Set symbol sizes by flux, making sure no symbol is smaller than 50 or
    # larger than 1000. Sizes grow with log10 of flux relative to the faintest
    # positive flux; non-positive fluxes get the minimum size.
    fluxes = LSM.getColValues('I')
    positive = fluxes[fluxes > 0.0]
    minflux = np.min(positive) if len(positive) > 0 else 0.0
    s = [min(1000.0, (1.0 + 2.0 * np.log10(flux / minflux)) * 50.0)
         if flux > 0.0 else 50.0
         for flux in fluxes]

    # Color sources by patch if grouped
    c = [0] * len(LSM)
    cp = []
    if LSM.hasPatches:
        for p, patchName in enumerate(LSM.getPatchNames()):
            indices = LSM.getRowIndex(patchName)
            cp.append(sm.to_rgba(p))
            for ind in indices:
                c[ind] = sm.to_rgba(p)
    else:
        c = [sm.to_rgba(0)] * nsrc

    # Plot sources (pad the view by 20 units around the data on WCS axes)
    if hasWCSaxes:
        ax.set_xlim(np.min(x) - 20, np.max(x) + 20)
        ax.set_ylim(np.min(y) - 20, np.max(y) + 20)
    plt.scatter(x, y, s=s, c=c)

    if LSM.hasPatches:
        RAp, Decp = LSM.getPatchPositions(asArray=True)
        # Patch positions left at (0, 0) are treated as unset and skipped.
        goodInd = np.where((RAp != 0.0) & (Decp != 0.0))
        if len(goodInd[0]) < len(RAp):
            log.info('Some patch positions are unset. Run setPatchPositions() '
                'before plotting to see patch positions and patch names.')
        xp, yp = radec2xy(RAp[goodInd], Decp[goodInd], midRA, midDec)
        plt.scatter(xp, yp, s=100, c=cp, marker='*')

    # Set axis labels, etc.
    if hasWCSaxes:
        RAAxis = ax.coords['ra']
        RAAxis.set_axislabel('RA', minpad=0.75)
        RAAxis.set_major_formatter('hh:mm:ss')
        DecAxis = ax.coords['dec']
        DecAxis.set_axislabel('Dec', minpad=0.75)
        DecAxis.set_major_formatter('dd:mm:ss')
        ax.coords.grid(color='black', alpha=0.5, linestyle='solid')
    else:
        plt.xlabel("RA (arb. units)")
        plt.ylabel("Dec (arb. units)")

    if labelBy is not None:
        if labelBy.lower() == 'source':
            labels = LSM.getColValues('name')
            xls = x
            yls = y
        elif labelBy.lower() == 'patch':
            if LSM.hasPatches:
                labels = LSM.getPatchNames()
                xls = xp
                yls = yp
            else:
                labels = LSM.getColValues('name')
                xls = x
                yls = y
        else:
            raise ValueError("The labelBy parameter must be one of 'source' or "
                "'patch'.")
        for label, xl, yl in zip(labels, xls, yls):
            plt.annotate(label, xy=(xl, yl), xytext=(-2, 2), textcoords=
                'offset points', ha='right', va='bottom')

    # Define coordinate formatter to show RA and Dec under mouse pointer
    ax.format_coord = formatCoord

    if fileName is not None:
        plt.savefig(fileName)
    else:
        plt.show()
    plt.close(fig)
Ejemplo n.º 19
0
    def test_update_clip_path_rectangular(self, tmpdir):
        """Check that the clip path follows limit changes made after a draw.

        A first ``savefig`` freezes the clip path returned by WCSAxes; the
        image and new limits set afterwards must still be clipped correctly
        in the returned figure.
        """
        figure = plt.figure()
        axes = WCSAxes(figure, [0.1, 0.1, 0.8, 0.8], aspect='equal')
        figure.add_axes(axes)

        initial_limits = (0., 2.)
        axes.set_xlim(*initial_limits)
        axes.set_ylim(*initial_limits)

        # Render once so that the clip path handed out by WCSAxes is frozen.
        scratch_path = tmpdir.join('nothing').strpath
        figure.savefig(scratch_path)

        # A 12x4 image, then limits matching its pixel extent.
        axes.imshow(np.zeros((12, 4)))
        axes.set_xlim(-0.5, 3.5)
        axes.set_ylim(-0.5, 11.5)

        return figure
Ejemplo n.º 20
0
    def test_update_clip_path_change_wcs(self, tmpdir):
        """Check that the clip path survives a WCS reset after a draw.

        Changing the WCS creates a new frame, so the clip path frozen by the
        first draw has to be carried over to that new frame.
        """
        figure = plt.figure()
        axes = WCSAxes(figure, [0.1, 0.1, 0.8, 0.8], aspect='equal')
        figure.add_axes(axes)

        initial_limits = (0., 2.)
        axes.set_xlim(*initial_limits)
        axes.set_ylim(*initial_limits)

        # Render once so that the clip path handed out by WCSAxes is frozen.
        scratch_path = tmpdir.join('nothing').strpath
        figure.savefig(scratch_path)

        # Replace the frame, then draw a 12x4 image and zoom to its extent.
        axes.reset_wcs()
        axes.imshow(np.zeros((12, 4)))
        axes.set_xlim(-0.5, 3.5)
        axes.set_ylim(-0.5, 11.5)

        return figure
Ejemplo n.º 21
0
    def plot(self,
             center,
             width,
             s=None,
             c=None,
             marker=None,
             stride=1,
             emin=None,
             emax=None,
             label=None,
             fontsize=18,
             fig=None,
             ax=None,
             **kwargs):
        """
        Plot event coordinates from this photon list in a scatter plot, 
        optionally restricting the photon energies which are plotted
        and using only a subset of the photons. 

        Parameters
        ----------
        center : array-like
            The RA, Dec of the center of the plot in degrees.
        width : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
            The width of the plot in arcminutes.
        s : integer, optional
            Size of the scatter marker in points^2.
        c : string, optional
            The color of the points.
        marker : string, optional
            The marker to use for the points in the scatter plot. Default: 'o'
        stride : integer, optional
            Plot every *stride* events. Default: 1
        emin : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
            The minimum energy of the photons to plot. Default is
            the minimum energy in the list.
        emax : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
            The maximum energy of the photons to plot. Default is
            the maximum energy in the list.
        label : string, optional
            The label of the spectrum. Default: None
        fontsize : int, optional
            Font size for labels and axes. Default: 18
        fig : :class:`~matplotlib.figure.Figure`, optional
            A Figure instance to plot in. Default: None, in which case the
            figure owning *ax* is used if *ax* is given, otherwise a new
            one is created.
        ax : :class:`~matplotlib.axes.Axes`, optional
            An Axes instance to plot in. Default: None, one will be
            created if not provided. If given, it must expose a ``wcs``
            attribute (e.g. a WCSAxes instance).

        Returns
        -------
        fig, ax : tuple
            The Figure and Axes that were plotted in.
        """
        import matplotlib.pyplot as plt
        from astropy.visualization.wcsaxes import WCSAxes
        if fig is None:
            # Reuse the figure of a caller-supplied Axes rather than creating
            # (and returning) an unrelated, empty figure.
            if ax is not None:
                fig = ax.figure
            else:
                fig = plt.figure(figsize=(10, 10))
        if ax is None:
            wcs = construct_wcs(center[0], center[1])
            ax = WCSAxes(fig, [0.15, 0.1, 0.8, 0.8], wcs=wcs)
            fig.add_axes(ax)
        else:
            wcs = ax.wcs
        if emin is None:
            emin = self.energy.value.min()
        else:
            emin = parse_value(emin, "keV")
        if emax is None:
            emax = self.energy.value.max()
        else:
            emax = parse_value(emax, "keV")
        # Keep only photons with emin <= E <= emax, then every stride-th one.
        idxs = np.logical_and(self.energy.value >= emin,
                              self.energy.value <= emax)
        ra = self.ra[idxs][::stride].value
        dec = self.dec[idxs][::stride].value
        x, y = wcs.wcs_world2pix(ra, dec, 1)
        ax.scatter(x, y, s=s, c=c, marker=marker, label=label, **kwargs)
        x0, y0 = wcs.wcs_world2pix(center[0], center[1], 1)
        # NOTE(review): the factor 60 converts arcmin to arcsec, which equals
        # pixels only if the WCS has a 1 arcsec/pixel scale -- confirm for
        # construct_wcs and for any caller-supplied ax.wcs.
        width = parse_value(width, "arcmin") * 60.0
        ax.set_xlim(x0 - 0.5 * width, x0 + 0.5 * width)
        ax.set_ylim(y0 - 0.5 * width, y0 + 0.5 * width)
        ax.set_xlabel("RA")
        ax.set_ylabel("Dec")
        ax.tick_params(axis='both', labelsize=fontsize)
        return fig, ax
Ejemplo n.º 22
0
def plot_state(directions_list, trim_names=True):
    """
    Plots the facets of a run in an interactive overview figure.

    Each direction other than 'field' is drawn as a polygon built from its
    vertices file. An extra 'field' direction is drawn as a small circle.
    Facet names, an info box, a legend, pick/key-press handlers and a
    60-second refresh timer are added before the figure is shown.

    Parameters
    ----------
    directions_list : list
        Direction objects of the run. Each must have ``name`` and
        ``vertices_file`` attributes, and the first must also have
        ``working_dir``. NOTE: a 'field' Direction is appended to this list
        in place, so the caller's list is modified.
    trim_names : bool, optional
        If True, label each facet with only the part of its name after the
        last underscore.

    Notes
    -----
    Sets the module-level globals ``midRA``, ``midDec``, ``fig``, ``at``,
    ``selected_direction`` and ``choose_from_list``. Also reads the
    module-level names ``hasWCSaxes``, ``hasaplpy`` and ``options``.
    """
    global midRA, midDec, fig, at, selected_direction, choose_from_list
    selected_direction = None
    choose_from_list = False

    # Set up coordinate system and figure
    points, midRA, midDec = factor.directions.getxy(directions_list)
    fig = plt.figure(1, figsize=(10,9))
    if hasWCSaxes:
        wcs = factor.directions.makeWCS(midRA, midDec)
        ax = WCSAxes(fig, [0.16, 0.1, 0.8, 0.8], wcs=wcs)
        fig.add_axes(ax)
    else:
        ax = plt.gca()

    # Place the 'field' marker: start at (min x, max y) and step by (-1, +1)
    # until no entry of ``points`` lies within 10 units of it.
    # NOTE(review): min(points[0]) / max(points[1]) read ``points`` as an
    # (x-row, y-row) pair, while ``for xy in points`` reads each element as
    # an (x, y) point. These readings disagree; confirm the shape returned by
    # factor.directions.getxy.
    field_x = min(points[0])
    field_y = max(points[1])
    adjust_xy = True
    while adjust_xy:
        adjust_xy = False
        for xy in points:
            dist = np.sqrt( (xy[0] - field_x)**2 + (xy[1] - field_y)**2 )
            if dist < 10.0:
                field_x -= 1
                field_y += 1
                adjust_xy = True
                break
    field_ra, field_dec = factor.directions.xy2radec([field_x], [field_y],
        refRA=midRA, refDec=midDec)
    field = Direction('field', field_ra[0], field_dec[0],
        factor_working_dir=directions_list[0].working_dir)
    # Mutates the caller's list (see docstring).
    directions_list.append(field)

    ax.set_title('Overview of Factor run in\n{}'.format(directions_list[0].working_dir))

    # Plot facets
    markers = []
    for direction in directions_list:
        if direction.name != 'field':
            vertices = read_vertices(direction.vertices_file)
            RAverts = vertices[0]
            Decverts = vertices[1]
            xverts, yverts = factor.directions.radec2xy(RAverts, Decverts,
                refRA=midRA, refDec=midDec)
            xyverts = [np.array([xp, yp]) for xp, yp in zip(xverts, yverts)]
            mpl_poly = Polygon(np.array(xyverts), edgecolor='#a9a9a9', facecolor='#F2F2F2',
                clip_box=ax.bbox, picker=3.0, linewidth=2)
        else:
            xverts = [field_x]
            yverts = [field_y]
            mpl_poly = Circle((field_x, field_y), radius=5.0, edgecolor='#a9a9a9', facecolor='#F2F2F2',
                clip_box=ax.bbox, picker=3.0, linewidth=2)
        # Attach run state to the patch so the pick handler can read it.
        mpl_poly.facet_name = direction.name
        mpl_poly.completed_ops = get_completed_ops(direction)
        mpl_poly.started_ops = get_started_ops(direction)
        mpl_poly.current_op = get_current_op(direction)
        set_patch_color(mpl_poly, direction)
        ax.add_patch(mpl_poly)

        # Add facet names (at the polygon centroid, or at the field marker)
        if direction.name != 'field':
            poly_tuple = tuple([(xp, yp) for xp, yp in zip(xverts, yverts)])
            xmid = SPolygon(poly_tuple).centroid.x
            ymid = SPolygon(poly_tuple).centroid.y
        else:
            xmid = field_x
            ymid = field_y
        if trim_names:
            name = direction.name.split('_')[-1]
        else:
            name = direction.name
        marker = ax.text(xmid, ymid, name, color='k', clip_on=True,
            clip_box=ax.bbox, ha='center', va='bottom')
        marker.set_zorder(1001)
        markers.append(marker)

    # Add info box (stored in the global ``at`` so handlers can update it)
    at = AnchoredText("Selected direction: None", prop=dict(size=12), frameon=True,
        loc=3)
    at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
    at.set_zorder(1002)
    ax.add_artist(at)

    ax.relim()
    ax.autoscale()
    ax.set_aspect('equal')

    if hasWCSaxes:
        RAAxis = ax.coords['ra']
        RAAxis.set_axislabel('RA', minpad=0.75)
        RAAxis.set_major_formatter('hh:mm:ss')
        DecAxis = ax.coords['dec']
        DecAxis.set_axislabel('Dec', minpad=0.75)
        DecAxis.set_major_formatter('dd:mm:ss')
        ax.coords.grid(color='black', alpha=0.5, linestyle='solid')
    else:
        plt.xlabel("RA (arb. units)")
        plt.ylabel("Dec (arb. units)")

    # Define coordinate formatter to show RA and Dec under mouse pointer
    ax.format_coord = formatCoord

    # Show legend: one proxy rectangle per facet state, plus one per
    # re-imaging pass (options['reimages']), each a darker shade of green.
    not_processed_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
        facecolor='#F2F2F2', linewidth=2)
    processing_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
        facecolor='#F2F5A9', linewidth=2)
    selfcal_ok_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
        facecolor='#A9F5A9', linewidth=2)
    selfcal_not_ok_patch = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
        facecolor='#A4A4A4', linewidth=2)
    processing_error = plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
        facecolor='#F5A9A9', linewidth=2)
    patch_list=[not_processed_patch, processing_patch, processing_error, selfcal_not_ok_patch, selfcal_ok_patch]
    label_list=['Unprocessed', 'Processing', 'Pipeline Error', 'Selfcal Failed', 'Selfcal OK']
    for i in range(options['reimages']):
        label_list.append('Image '+str(i+1))
        color=(0.66/(i+2)**0.5, 0.96/(i+2)**0.5, 0.66/(i+2)**0.5, 1.0)
        reimage_patch=plt.Rectangle((0, 0), 1, 1, edgecolor='#a9a9a9',
            facecolor=color, linewidth=2)
        patch_list.append(reimage_patch)
    l = ax.legend(patch_list, label_list, loc="upper right")
    l.set_zorder(1002)

    # Add check for mouse clicks and key presses
    fig.canvas.mpl_connect('pick_event', on_pick)
    fig.canvas.mpl_connect('key_press_event', on_press)

    # Add timer to update the plot every 60 seconds
    timer = fig.canvas.new_timer(interval=60000)
    timer.add_callback(update_plot)
    timer.start()

    # Show plot
    plt.show()
    plt.close(fig)

    # Clean up any temp casacore images (only when aplpy is unavailable);
    # removal errors are deliberately ignored.
    if not hasaplpy:
        if os.path.exists('/tmp/tempimage'):
            try:
                shutil.rmtree('/tmp/tempimage')
            except OSError:
                pass
Ejemplo n.º 23
0
    def plot(self, center, width, s=None, c=None, marker=None, stride=1,
             emin=None, emax=None, label=None, fontsize=18, fig=None, 
             ax=None, **kwargs):
        """
        Plot event coordinates from this photon list in a scatter plot, 
        optionally restricting the photon energies which are plotted
        and using only a subset of the photons. 

        Parameters
        ----------
        center : array-like
            The RA, Dec of the center of the plot in degrees.
        width : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
            The width of the plot in arcminutes.
        s : integer, optional
            Size of the scatter marker in points^2.
        c : string, optional
            The color of the points.
        marker : string, optional
            The marker to use for the points in the scatter plot. Default: 'o'
        stride : integer, optional
            Plot every *stride* events. Default: 1
        emin : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
            The minimum energy of the photons to plot. Default is
            the minimum energy in the list.
        emax : float, (value, unit) tuple, or :class:`~astropy.units.Quantity`
            The maximum energy of the photons to plot. Default is
            the maximum energy in the list.
        label : string, optional
            The label of the spectrum. Default: None
        fontsize : int, optional
            Font size for labels and axes. Default: 18
        fig : :class:`~matplotlib.figure.Figure`, optional
            A Figure instance to plot in. Default: None, in which case the
            figure owning *ax* is used if *ax* is given, otherwise a new
            one is created.
        ax : :class:`~matplotlib.axes.Axes`, optional
            An Axes instance to plot in. Default: None, one will be
            created if not provided. If given, it must expose a ``wcs``
            attribute (e.g. a WCSAxes instance).

        Returns
        -------
        fig, ax : tuple
            The Figure and Axes that were plotted in.
        """
        import matplotlib.pyplot as plt
        from astropy.visualization.wcsaxes import WCSAxes
        if fig is None:
            # Reuse the figure of a caller-supplied Axes rather than creating
            # (and returning) an unrelated, empty figure.
            fig = ax.figure if ax is not None else plt.figure(figsize=(10, 10))
        if ax is None:
            wcs = construct_wcs(center[0], center[1])
            ax = WCSAxes(fig, [0.15, 0.1, 0.8, 0.8], wcs=wcs)
            fig.add_axes(ax)
        else:
            wcs = ax.wcs
        if emin is None:
            emin = self.energy.value.min()
        else:
            emin = parse_value(emin, "keV")
        if emax is None:
            emax = self.energy.value.max()
        else:
            emax = parse_value(emax, "keV")
        # Keep only photons with emin <= E <= emax, then every stride-th one.
        idxs = np.logical_and(self.energy.value >= emin, self.energy.value <= emax)
        ra = self.ra[idxs][::stride].value
        dec = self.dec[idxs][::stride].value
        x, y = wcs.wcs_world2pix(ra, dec, 1)
        ax.scatter(x, y, s=s, c=c, marker=marker, label=label, **kwargs)
        x0, y0 = wcs.wcs_world2pix(center[0], center[1], 1)
        # NOTE(review): the factor 60 converts arcmin to arcsec, which equals
        # pixels only if the WCS has a 1 arcsec/pixel scale -- confirm.
        width = parse_value(width, "arcmin")*60.0
        ax.set_xlim(x0-0.5*width, x0+0.5*width)
        ax.set_ylim(y0-0.5*width, y0+0.5*width)
        ax.set_xlabel("RA")
        ax.set_ylabel("Dec")
        ax.tick_params(axis='both', labelsize=fontsize)
        return fig, ax
Ejemplo n.º 24
0
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from astropy.io import fits
import pyregion

# read in the image
xray_name = "pspc_skyview.fits"
f_xray = fits.open(xray_name)
# NOTE(review): f_xray is never closed in the visible code; later code in
# this example may still read from it.

# Use WCS-aware axes when astropy's WCS/WCSAxes can be imported; otherwise
# fall back to a plain Matplotlib subplot.
try:
    from astropy.wcs import WCS
    from astropy.visualization.wcsaxes import WCSAxes

    wcs = WCS(f_xray[0].header)
    fig = plt.figure()
    ax = WCSAxes(fig, [0.1, 0.1, 0.8, 0.8], wcs=wcs)
    fig.add_axes(ax)
except ImportError:
    ax = plt.subplot(111)

# Grayscale display with a fixed value range; origin="lower" puts row 0 at
# the bottom.
ax.imshow(f_xray[0].data, cmap=cm.gray, vmin=0., vmax=0.00038, origin="lower")

# Load the region file and convert its shapes to image coordinates using
# the primary HDU header.
reg_name = "test.reg"
r = pyregion.open(reg_name).as_imagecoord(header=f_xray[0].header)

from pyregion.mpl_helper import properties_func_default


# Use custom function for patch attribute
def fixed_color(shape, saved_attrs):
    attr_list, attr_dict = saved_attrs
Ejemplo n.º 25
0
    def test_rcparams(self):
        """Build a WCSAxes figure under custom rcParams.

        Label, spine, face, tick and grid settings are all overridden, so
        the returned figure reflects how WCSAxes picks each of them up.
        """
        custom_rc = {
            'axes.labelcolor': 'purple',
            'axes.labelsize': 14,
            'axes.labelweight': 'bold',
            'axes.linewidth': 3,
            'axes.facecolor': '0.5',
            'axes.edgecolor': 'green',
            'xtick.color': 'red',
            'xtick.labelsize': 8,
            'xtick.direction': 'in',
            'xtick.minor.visible': True,
            'xtick.minor.size': 5,
            'xtick.major.size': 20,
            'xtick.major.width': 3,
            'xtick.major.pad': 10,
            'grid.color': 'blue',
            'grid.linestyle': ':',
            'grid.linewidth': 1,
            'grid.alpha': 0.5
        }

        with rc_context(custom_rc):
            figure = plt.figure(figsize=(6, 6))
            axes = WCSAxes(figure, [0.15, 0.1, 0.7, 0.7], wcs=None)
            figure.add_axes(axes)
            axes.set_xlim(-0.5, 2)
            axes.set_ylim(-0.5, 2)
            axes.grid()
            axes.set_xlabel('X label')
            axes.set_ylabel('Y label')
            # Drop overlapping tick labels on both coordinates.
            for coord_index in (0, 1):
                axes.coords[coord_index].set_ticklabel(exclude_overlapping=True)
            return figure
Ejemplo n.º 26
0
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from astropy.io import fits
import pyregion

# Read in the X-ray image and keep a handle on its primary HDU.
xray_name = "pspc_skyview.fits"
f_xray = fits.open(xray_name)
xray_hdu = f_xray[0]

# Prefer WCS-aware axes; fall back to a plain subplot when the astropy WCS
# machinery cannot be imported.
try:
    from astropy.wcs import WCS
    from astropy.visualization.wcsaxes import WCSAxes

    wcs = WCS(xray_hdu.header)
    fig = plt.figure()
    ax = WCSAxes(fig, [0.1, 0.1, 0.8, 0.8], wcs=wcs)
    fig.add_axes(ax)
except ImportError:
    ax = plt.subplot(111)

ax.imshow(xray_hdu.data, cmap=cm.gray, vmin=0., vmax=0.00038, origin="lower")

# Load the regions in image coordinates and draw every shape and label.
reg_name = "test.reg"
r = pyregion.open(reg_name).as_imagecoord(xray_hdu.header)

patch_list, text_list = r.get_mpl_patches_texts()
for shape_patch in patch_list:
    ax.add_patch(shape_patch)
for label_text in text_list:
    ax.add_artist(label_text)
Ejemplo n.º 27
0
import pyregion

# Region files to draw, one per panel (left, right).
region_list = [
    "test_text.reg",
    "test_context.reg",
]

# Create figure
fig = plt.figure(figsize=(8, 4))

# Parse WCS information
# NOTE(review): plt, Header, WCS and WCSAxes are not imported in this
# snippet; presumably matplotlib.pyplot, astropy.io.fits.Header,
# astropy.wcs.WCS and astropy.visualization.wcsaxes.WCSAxes -- confirm.
header = Header.fromtextfile('sample_fits01.header')
wcs = WCS(header)

# Create two side-by-side axes sharing the same WCS
ax1 = WCSAxes(fig, [0.1, 0.1, 0.4, 0.8], wcs=wcs)
fig.add_axes(ax1)
ax2 = WCSAxes(fig, [0.5, 0.1, 0.4, 0.8], wcs=wcs)
fig.add_axes(ax2)

# Hide labels on y axis of the right-hand panel
ax2.coords[1].set_ticklabel_position('')

for ax, reg_name in zip([ax1, ax2], region_list):

    # Same pixel window and a 1:1 aspect on both panels.
    ax.set_xlim(300, 1300)
    ax.set_ylim(300, 1300)
    ax.set_aspect(1)

    # Convert this panel's regions to image coordinates.
    # NOTE(review): ``r`` is unused in the visible code; the drawing step
    # appears to be missing from this snippet.
    r = pyregion.open(reg_name).as_imagecoord(header)
Ejemplo n.º 28
0
    def test_update_clip_path_rectangular(self, tmpdir):
        """Check that the clip path follows limit changes made after a draw.

        A first ``savefig`` freezes the clip path returned by WCSAxes; the
        image and new limits set afterwards must still be clipped correctly
        in the returned figure.
        """
        figure = plt.figure()
        axes = WCSAxes(figure, [0.1, 0.1, 0.8, 0.8], aspect='equal')
        figure.add_axes(axes)

        initial_limits = (0., 2.)
        axes.set_xlim(*initial_limits)
        axes.set_ylim(*initial_limits)

        # Render once so that the clip path handed out by WCSAxes is frozen.
        scratch_path = tmpdir.join('nothing').strpath
        figure.savefig(scratch_path)

        # A 12x4 image, then limits matching its pixel extent.
        axes.imshow(np.zeros((12, 4)))
        axes.set_xlim(-0.5, 3.5)
        axes.set_ylim(-0.5, 11.5)

        return figure
Ejemplo n.º 29
0
@pytest.mark.parametrize("test_input, test_kwargs, expected_error", [
    (cube[0, 0], {"axes_coordinates": np.arange(10, 10 + cube_unit[0, 0].data.shape[0]),
                  "axes_units": u.C}, TypeError),
    (cube[0, 0], {"data_unit": u.C}, TypeError)
])
def test_cube_plot_1D_errors(test_input, test_kwargs, expected_error):
    """Check that plotting a 1D cube with these keyword arguments raises
    ``expected_error``.

    The return value of ``plot`` is irrelevant here, so it is not bound.
    """
    with pytest.raises(expected_error):
        test_input.plot(**test_kwargs)


@pytest.mark.parametrize("test_input, test_kwargs, expected_values", [
    (cube[0], {},
     (np.ma.masked_array(cube[0].data, cube[0].mask), "time [min]", "em.wl [m]",
      (-0.5, 3.5, -0.5, 2.5))),

    (cube_spatial, {'axes': WCSAxes(plt.figure(), (0, 0, 1, 1), wcs=cube_spatial.wcs)},
     (cube_spatial.data,
      "custom:pos.helioprojective.lat [deg]",
      "custom:pos.helioprojective.lon [deg]",
      (-0.5, 3.5, -0.5, 2.5))),

    (cube[0], {"axes_coordinates": ["bye", None], "axes_units": [None, u.cm]},
     (np.ma.masked_array(cube[0].data, cube[0].mask), "bye [m]", "em.wl [cm]",
      (0.0, 3.0, 2e-9, 6e-9))),

    (cube[0], {"axes_coordinates": [np.arange(10, 10 + cube[0].data.shape[1]),
                                    u.Quantity(np.arange(10, 10 + cube[0].data.shape[0]), unit=u.m)],
               "axes_units": [None, u.cm]},
     (np.ma.masked_array(cube[0].data, cube[0].mask), " [None]", " [cm]", (10, 13, 1000, 1200))),

    (cube[0], {"axes_coordinates": [np.arange(10, 10 + cube[0].data.shape[1]),