Ejemplo n.º 1
0
def display_isophote(img,
                     ell,
                     iso_color='orangered',
                     zoom=None,
                     **display_kwargs):
    """Visualize the isophotes on top of an image.

    Every other isophote in ``ell`` (indices 0, 2, 4, ...) is drawn as an
    unfilled ellipse on the top-left panel of a 2x2 grid.

    Parameters
    ----------
    img : 2-D numpy array
        Image to display, indexed as ``img[row, column]``.
    ell : iterable of records
        Isophotes; each record must provide ``x0``, ``y0``, ``sma``, ``ell``
        and ``pa`` fields.
    iso_color : str, optional
        Edge color of the isophote ellipses. Default: 'orangered'.
    zoom : float, optional
        If given, only the central ``1 / zoom`` fraction of the image along
        each axis is shown, and the isophote centers are shifted to match
        the cropped image. Default: None (show the whole image).
    **display_kwargs
        Extra keyword arguments forwarded to ``display_single``.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure holding the plot.
    """
    fig = plt.figure(figsize=(12, 12))
    fig.subplots_adjust(left=0.0,
                        right=1.0,
                        bottom=0.0,
                        top=1.0,
                        wspace=0.00,
                        hspace=0.00)
    gs = gridspec.GridSpec(2, 2)
    gs.update(wspace=0.0, hspace=0.00)

    # Whole central galaxy: Step 2
    ax1 = fig.add_subplot(gs[0])
    ax1.yaxis.set_major_formatter(NullFormatter())
    ax1.xaxis.set_major_formatter(NullFormatter())

    if zoom is not None:
        # numpy images are indexed [row, column] = [y, x], while the isophote
        # centers (x0, y0) are (column, row) pixel coordinates. The x offset
        # must therefore come from the number of columns (shape[1]).
        ny_img, nx_img = img.shape[0], img.shape[1]
        nx_size, ny_size = int(nx_img / zoom), int(ny_img / zoom)
        x_off = int((nx_img - nx_size) / 2)
        y_off = int((ny_img - ny_size) / 2)
        img = img[y_off:y_off + ny_size, x_off:x_off + nx_size]
    else:
        x_off, y_off = 0, 0

    ax1 = display_single(img, ax=ax1, scale_bar=False, **display_kwargs)

    for k, iso in enumerate(ell):
        # Only draw every other isophote to keep the plot readable.
        if k % 2 == 0:
            e = Ellipse(xy=(iso['x0'] - x_off, iso['y0'] - y_off),
                        height=iso['sma'] * 2.0,
                        width=iso['sma'] * 2.0 * (1.0 - iso['ell']),
                        angle=iso['pa'])
            e.set_facecolor('none')
            e.set_edgecolor(iso_color)
            e.set_alpha(0.9)
            e.set_linewidth(2.0)
            ax1.add_artist(e)
    ax1.set_aspect('equal')

    return fig
Ejemplo n.º 2
0
Archivo: visual.py Proyecto: j-dr/riker
def _overplot_blank_axis(fig, cell):
    """Add a subplot at ``cell`` with its tick labels hidden."""
    axis = fig.add_subplot(cell)
    axis.yaxis.set_major_formatter(NullFormatter())
    axis.xaxis.set_major_formatter(NullFormatter())
    return axis


def _overplot_isophote(axis, iso, edgecolor, alpha, linewidth, linestyle=None):
    """Draw one unfilled isophote ellipse from an ``iso`` record on ``axis``."""
    shape = Ellipse(xy=(iso['x0'], iso['y0']),
                    height=iso['sma'] * 2.0,
                    width=iso['sma'] * 2.0 * (1.0 - iso['ell']),
                    angle=iso['pa'])
    shape.set_facecolor('none')
    shape.set_edgecolor(edgecolor)
    if linestyle is not None:
        shape.set_linestyle(linestyle)
    shape.set_alpha(alpha)
    shape.set_linewidth(linewidth)
    axis.add_artist(shape)


def _overplot_label(axis, text):
    """Write ``text`` in the lower-left corner of ``axis`` on a white box."""
    axis.text(0.05,
              0.06,
              text,
              fontsize=25,
              transform=axis.transAxes,
              bbox=dict(facecolor='w', edgecolor='none', alpha=0.8))


def overplot_ellipse(ell_plot, zmin=3.5, zmax=10.5):
    """Overplot the elliptical isophotes on the stellar mass maps.

    Draws a 2x2 grid of log-stretched stellar mass maps:

    * top-left: total map with every third step-2 isophote (sma >= 6);
    * top-right: total map with a scale bar, the ID / stellar mass label and
      one dashed step-3 in-situ isophote;
    * bottom-left / bottom-right: in-situ / ex-situ maps with every third
      step-2 isophote (sma >= 6) of that component.

    Parameters
    ----------
    ell_plot : dict
        Data needed for the plot. Keys read here: 'mass_gal', 'mass_ins',
        'mass_exs', 'ell_gal_2', 'ell_ins_2', 'ell_exs_2', 'ell_ins_3',
        'pix', 'catsh_id', 'logms'. The 'ell_*' entries may be None, in
        which case no isophote is drawn on that panel.
    zmin : float, optional
        Lower limit of the log10(Mass) display range. Default: 3.5
    zmax : float, optional
        Upper limit of the log10(Mass) display range. Default: 10.5

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure holding the four panels.
    """
    fig = plt.figure(figsize=(10, 10))
    fig.subplots_adjust(left=0.005,
                        right=0.995,
                        bottom=0.005,
                        top=0.995,
                        wspace=0.00,
                        hspace=0.00)

    grid = GridSpec(2, 2)
    grid.update(wspace=0.0, hspace=0.00)

    # Display settings shared by all four stellar-mass panels.
    mass_style = dict(stretch='log10',
                      zmin=zmin,
                      zmax=zmax,
                      cmap=IMG_CMAP,
                      no_negative=True)

    # Top-left panel: total galaxy with its step-2 isophotes.
    ax_total = _overplot_blank_axis(fig, grid[0])
    ax_total = display_single(ell_plot['mass_gal'],
                              ax=ax_total,
                              color_bar=True,
                              scale_bar=False,
                              color_bar_color='k',
                              **mass_style)
    isophotes = ell_plot['ell_gal_2']
    if isophotes is not None:
        for idx, iso in enumerate(isophotes):
            if idx % 3 == 0 and iso['sma'] >= 6.0:
                _overplot_isophote(ax_total, iso, 'k', 0.6, 2.0)
        ax_total.set_aspect('equal')
    _overplot_label(ax_total, r'$\rm Total$')

    # Top-right panel: total galaxy with scale bar, ID and the average shape.
    ax_shape = _overplot_blank_axis(fig, grid[1])
    ax_shape = display_single(ell_plot['mass_gal'],
                              ax=ax_shape,
                              color_bar=False,
                              scale_bar=True,
                              pixel_scale=1.,
                              physical_scale=ell_plot['pix'],
                              scale_bar_loc='right',
                              scale_bar_length=50.,
                              scale_bar_color='k',
                              scale_bar_y_offset=1.3,
                              **mass_style)
    id_text = r'$\mathrm{ID}: %d\ \ \log M_{\star}: %5.2f$' % (
        ell_plot['catsh_id'], ell_plot['logms'])
    ax_shape.text(0.5,
                  0.92,
                  id_text,
                  fontsize=21,
                  transform=ax_shape.transAxes,
                  horizontalalignment='center',
                  verticalalignment='center',
                  bbox=dict(facecolor='w', edgecolor='none', alpha=0.5))

    isophotes = ell_plot['ell_ins_3']
    if isophotes is not None:
        n_iso = len(isophotes)
        # Pick a single representative isophote: the 6th from the outside for
        # long profiles, otherwise the outermost one.
        target = n_iso - 6 if n_iso > 15 else n_iso - 1
        for idx, iso in enumerate(isophotes):
            if idx == target:
                _overplot_isophote(ax_shape, iso, 'k', 0.8, 2.5,
                                   linestyle='--')
        ax_shape.set_aspect('equal')
    _overplot_label(ax_shape, r'$\rm Total$')

    # Bottom row: in-situ and ex-situ components with their step-2 isophotes.
    bottom_panels = (
        (grid[2], 'mass_ins', 'ell_ins_2', 'orangered', r'$\rm In\ situ$'),
        (grid[3], 'mass_exs', 'ell_exs_2', 'steelblue', r'$\rm Ex\ situ$'),
    )
    for cell, map_key, iso_key, color, label in bottom_panels:
        ax_comp = _overplot_blank_axis(fig, cell)
        ax_comp = display_single(ell_plot[map_key],
                                 ax=ax_comp,
                                 color_bar=False,
                                 scale_bar=False,
                                 **mass_style)
        isophotes = ell_plot[iso_key]
        if isophotes is not None:
            for idx, iso in enumerate(isophotes):
                if idx % 3 == 0 and iso['sma'] >= 6.0:
                    _overplot_isophote(ax_comp, iso, color, 0.9, 2.0)
            ax_comp.set_aspect('equal')
        _overplot_label(ax_comp, label)

    return fig
Ejemplo n.º 3
0
Archivo: visual.py Proyecto: j-dr/riker
def show_maps(maps, aper, cid=None, logms=None, figsize=(15, 15)):
    """Visualize the stellar mass, age, and metallicity maps.

    The figure is a 3x3 grid. Rows are stellar mass (log10 stretch), age
    (linear stretch) and metallicity (log10 of ``Z / Z_SUN``). Columns are the
    whole galaxy, the in-situ component and the ex-situ component. Only the
    first column carries a color bar and a row label. The stellar-mass row
    also gets column titles, and its first panel marks the center and the
    aperture shape.

    Parameters
    ----------
    maps : dict
        Dictionary that contains all stellar mass, age, and metallicity maps,
        keyed 'mass_gal', 'mass_ins', 'mass_exs', 'age_gal', ... 'met_exs'.
    aper : dict
        Dictionary that contains basic shape information of the galaxy.
        Keys read here: 'x', 'y', 'ba', 'pa'.
    cid : int, optional
        `catsh_id`, sub-halo ID in the simulation. Used to identify galaxy.
        Default: None
    logms : float, optional
        Stellar mass in log10 unit. Default: None.
        The ID label is only drawn when both `cid` and `logms` are given.
    figsize : tuple, optional
        Size of the 3x3 figure. Default: (15, 15)

    Returns
    -------
    fig_sum : matplotlib.figure.Figure
        The figure holding the nine panels.
    """
    # Setup the figure and grid of axes
    fig_sum = plt.figure(figsize=figsize, constrained_layout=False)
    grid_sum = fig_sum.add_gridspec(3, 3, wspace=0.0, hspace=0.0)
    fig_sum.subplots_adjust(left=0.005,
                            right=0.995,
                            bottom=0.005,
                            top=0.995,
                            wspace=0.00,
                            hspace=0.00)

    # One entry per row: map-name prefix, stretch, display range, whether the
    # map is normalized by the solar metallicity, row label and its x position.
    rows = [
        ('mass', 'log10', 6.0, 10.5, False,
         r'$\log[M_{\star}/M_{\odot}]$', 0.05),
        ('age', 'linear', 1.0, 8.5, False, r'$\rm Age/Gyr$', 0.06),
        ('met', 'log10', -0.6, 0.9, True,
         r'$\log[Z_{\star}/Z_{\odot}]$', 0.06),
    ]
    components = ['gal', 'ins', 'exs']
    # Column titles, drawn on the stellar-mass row only.
    column_titles = [r'$\rm Total$', r'$\rm In\ situ$', r'$\rm Ex\ situ$']

    for i_row, (prefix, stretch, zmin, zmax, in_zsun, row_label,
                label_x) in enumerate(rows):
        for i_col, comp in enumerate(components):
            ax = fig_sum.add_subplot(grid_sum[i_row * 3 + i_col])

            data = maps[prefix + '_' + comp]
            if in_zsun:
                data = data / Z_SUN

            if i_col == 0:
                # First column: show the color bar and the row label.
                display_single(data,
                               ax=ax,
                               stretch=stretch,
                               zmin=zmin,
                               zmax=zmax,
                               color_bar=True,
                               scale_bar=False,
                               no_negative=True,
                               color_bar_height='5%',
                               color_bar_width='85%',
                               color_bar_fontsize=20,
                               cmap=IMG_CMAP,
                               color_bar_color='k')
                ax.text(label_x,
                        0.06,
                        row_label,
                        fontsize=25,
                        transform=ax.transAxes,
                        bbox=dict(facecolor='w', edgecolor='none', alpha=0.8))
            else:
                display_single(data,
                               ax=ax,
                               stretch=stretch,
                               zmin=zmin,
                               zmax=zmax,
                               color_bar=False,
                               scale_bar=False,
                               no_negative=True,
                               cmap=IMG_CMAP,
                               color_bar_color='k')

            if prefix != 'mass':
                continue

            if i_col == 0:
                # Label the center of the galaxy
                ax.scatter(aper['x'],
                           aper['y'],
                           marker='+',
                           s=200,
                           c='orangered',
                           linewidth=2.0,
                           alpha=0.6)
                # Show the isophote shape
                e = Ellipse(xy=(aper['x'], aper['y']),
                            height=80.0 * aper['ba'],
                            width=80.0,
                            angle=aper['pa'])
                e.set_facecolor('none')
                e.set_edgecolor('orangered')
                e.set_alpha(0.5)
                e.set_linewidth(2.0)
                ax.add_artist(e)
            if i_col == 1 and cid is not None and logms is not None:
                # Put the ID
                ax.text(0.5,
                        0.88,
                        r'$\mathrm{ID}: %d\ \ \log M_{\star}: %5.2f$' %
                        (cid, logms),
                        fontsize=25,
                        transform=ax.transAxes,
                        horizontalalignment='center',
                        bbox=dict(facecolor='w', edgecolor='none', alpha=0.7))
            # Column title. The "In situ" title is drawn even when no ID is
            # given; it used to be hidden inside the ID condition.
            ax.text(0.75,
                    0.06,
                    column_titles[i_col],
                    fontsize=25,
                    transform=ax.transAxes)

    return fig_sum
Ejemplo n.º 4
0
def image_gaia_stars(image,
                     wcs,
                     radius=None,
                     center=None,
                     pixel=0.168,
                     mask_a=694.7,
                     mask_b=4.04,
                     verbose=False,
                     visual=False,
                     size_buffer=1.4,
                     tap_url=None,
                     img_size=(8, 8)):
    """Search for bright stars using GAIA catalog.

    Parameters
    ----------
    image : 2-D numpy array
        Image indexed as ``image[row, column]``.
    wcs : astropy.wcs.WCS
        WCS of ``image``; used to convert between pixel and sky coordinates.
    radius : astropy Quantity, optional
        Cone-search radius. If None, a box of the image size times
        ``size_buffer`` is searched instead.
    center : SkyCoord, optional
        Search center. If None, the image center is used.
    pixel : float, optional
        Pixel scale in arcsec/pixel. Default: 0.168
    mask_a, mask_b : float, optional
        Mask radius (arcsec) is ``mask_a * exp(-G / mask_b)`` where G is
        ``phot_g_mean_mag``.
    verbose : bool, optional
        Print the search parameters.
    visual : bool, optional
        Also return a figure showing the stars and their masks.
    size_buffer : float, optional
        Factor by which the search box exceeds the image size. Default: 1.4
    tap_url : str, optional
        If given, query this TAP service instead of the default Gaia archive.
    img_size : tuple, optional
        Figure size used when ``visual`` is True.

    Returns
    -------
    gaia_results : astropy Table or None
        Matched stars, with added 'x_pix', 'y_pix' and 'rmask_arcsec' columns.
        ``None`` when the query returns no stars.
    fig : matplotlib Figure
        Only returned, as ``(gaia_results, fig)``, when ``visual`` is True
        and stars were found.

    Raises
    ------
    TypeError
        If ``center`` is not a SkyCoord or ``radius`` is not a Quantity.

    TODO:
        Should be absorbed by the object for image later.

    TODO:
        Should have a version that just uses the local catalog.
    """
    # numpy images are indexed [row, column] = [y, x], whereas WCS pixel
    # coordinates are (x, y). The x extent is therefore the column count.
    ny_img, nx_img = image.shape[0], image.shape[1]

    # Central coordinate
    if center is None:
        ra_cen, dec_cen = wcs.wcs_pix2world(nx_img / 2, ny_img / 2, 0)
        img_cen_ra_dec = SkyCoord(ra_cen,
                                  dec_cen,
                                  unit=('deg', 'deg'),
                                  frame='icrs')
        if verbose:
            print("# The center of the search: RA={:9.5f}, DEC={:9.5f}".format(
                ra_cen, dec_cen))
    else:
        if not isinstance(center, SkyCoord):
            raise TypeError(
                "# The center coordinate should be a SkyCoord object")
        img_cen_ra_dec = center
        if verbose:
            print("# The center of the search: RA={:9.5f}, DEC={:9.5f}".format(
                center.ra, center.dec))

    # Width and height of the search box
    if radius is None:
        img_search_x = Quantity(pixel * nx_img * size_buffer, u.arcsec)
        img_search_y = Quantity(pixel * ny_img * size_buffer, u.arcsec)
        if verbose:
            print("# The width of the search: {:7.1f}".format(img_search_x))
            print("# The height of the search: {:7.1f}".format(img_search_y))
    else:
        if not isinstance(radius, Quantity):
            raise TypeError(
                "# Searching radius needs to be an Astropy Quantity.")
        if verbose:
            print("# The searching radius is: {:7.2f}".format(radius))

    # Search for stars: pick the client once, then run a single query.
    with suppress_stdout():
        if tap_url is not None:
            from astroquery.gaia import TapPlus, GaiaClass
            gaia = GaiaClass(TapPlus(url=tap_url))
        else:
            from astroquery.gaia import Gaia as gaia

        if radius is not None:
            gaia_results = gaia.query_object_async(
                coordinate=img_cen_ra_dec, radius=radius, verbose=verbose)
        else:
            gaia_results = gaia.query_object_async(
                coordinate=img_cen_ra_dec,
                width=img_search_x,
                height=img_search_y,
                verbose=verbose)

    if not gaia_results:
        return None

    # Convert the (RA, Dec) of stars into pixel coordinate
    ra_gaia = np.asarray(gaia_results['ra'])
    dec_gaia = np.asarray(gaia_results['dec'])
    x_gaia, y_gaia = wcs.wcs_world2pix(ra_gaia, dec_gaia, 0)

    # Mask radius grows exponentially with brightness (smaller G magnitude)
    rmask_gaia_arcsec = mask_a * np.exp(
        -gaia_results['phot_g_mean_mag'] / mask_b)

    # Update the catalog
    gaia_results.add_column(Column(data=x_gaia, name='x_pix'))
    gaia_results.add_column(Column(data=y_gaia, name='y_pix'))
    gaia_results.add_column(
        Column(data=rmask_gaia_arcsec, name='rmask_arcsec'))

    if not visual:
        return gaia_results

    fig = plt.figure(figsize=img_size)
    ax1 = fig.add_subplot(111)

    ax1 = display_single(image, ax=ax1)
    # Plot a circular mask for each star (radius converted to pixels)
    for star in gaia_results:
        smask = mpl_ellip(xy=(star['x_pix'], star['y_pix']),
                          width=(2.0 * star['rmask_arcsec'] / pixel),
                          height=(2.0 * star['rmask_arcsec'] / pixel),
                          angle=0.0)
        smask.set_facecolor('coral')
        smask.set_edgecolor('coral')
        smask.set_alpha(0.3)
        ax1.add_artist(smask)

    # Show stars
    ax1.scatter(gaia_results['x_pix'],
                gaia_results['y_pix'],
                c='orangered',
                s=100,
                alpha=0.9,
                marker='+')

    ax1.set_xlim(0, nx_img)
    ax1.set_ylim(0, ny_img)

    return gaia_results, fig
Ejemplo n.º 5
0
def evaluate_sky(img, sigma=1.5, radius=10, pixel_scale=0.168, central_mask_radius=7.0,
                 threshold=0.005, deblend_cont=0.001, deblend_nthresh=20,
                 clean_param=1.0, show_fig=True, show_hist=True, f_factor=None):
    '''Evaluate the mean sky value.

    Sources are detected twice with ``extract_obj``: once at ``sigma``, then
    on the masked residual at ``sigma + 1``. Their ellipses are enlarged into
    a background mask, and the central object is masked as well. The
    unmasked image is then block-reduced by ``f_factor`` pixels,
    sigma-clipped and summarized.

    Parameters:
    ----------
    img: 2-D numpy array, the input image.
    sigma: float, detection threshold of the first ``extract_obj`` pass.
    radius: float, ``r`` scale used by ``sep.mask_ellipse`` to enlarge the
        first-pass detections; the second pass uses ``radius / 2``.
    pixel_scale: float, pixel scale in arcsec/pixel.
    central_mask_radius: float, ``r`` scale used to mask the central object
        when its half-flux radius is below 1/8 of the image size.
    threshold: float, currently unused.
    deblend_cont, deblend_nthresh, clean_param: forwarded to ``extract_obj``.
    show_fig: bool. If True, it will show you the masked sky image and save
        it as ``./<random int>-bkg.png``.
    show_hist: bool. If True, it will show you the histogram of the sky value.
    f_factor: int, block size in pixels used for binning.
        Default: ``round(6 / pixel_scale)``.

    Returns:
    -------
    median: median of the binned background values divided by ``f_factor**2``.
    std: standard deviation of the binned background values divided by
        ``f_factor`` and by ``sqrt(len(sample))``.
    sample: 1-D array of the sigma-clipped binned background values.
    '''
    import copy
    import sep
    from scipy import stats
    from astropy.stats import sigma_clip
    from astropy.nddata import block_reduce
    from slug.imutils import extract_obj

    b = 35  # Box size
    f = 5   # Filter width

    bkg = sep.Background(img, maskthresh=0, bw=b, bh=b, fw=f, fh=f)
    # first time
    objects, segmap = extract_obj(img - bkg.globalback, b=b, f=f, sigma=sigma,
                                  minarea=20, pixel_scale=pixel_scale,
                                  deblend_nthresh=deblend_nthresh, deblend_cont=deblend_cont,
                                  clean_param=clean_param, show_fig=False)

    # Every segmented pixel is masked; then blow up the mask
    bkg_mask_1 = segmap != 0
    for obj in objects:
        sep.mask_ellipse(bkg_mask_1, obj['x'], obj['y'], obj['a'], obj['b'], obj['theta'], r=radius)

    data = img - bkg.globalback
    data[bkg_mask_1] = 0

    # Second time: fainter sources on the residual, higher threshold
    obj_lthre, seg_lthre = extract_obj(data, b=b, f=f, sigma=sigma + 1,
                                       minarea=5, pixel_scale=pixel_scale,
                                       deblend_nthresh=deblend_nthresh, deblend_cont=deblend_cont,
                                       clean_param=clean_param, show_fig=False)
    bkg_mask_2 = seg_lthre != 0
    # Blow up the mask
    for obj in obj_lthre:
        sep.mask_ellipse(bkg_mask_2, obj['x'], obj['y'], obj['a'], obj['b'], obj['theta'], r=radius / 2)

    bkg_mask = bkg_mask_1 | bkg_mask_2

    # Central object: the segment covering the central pixel. Segment label k
    # corresponds to objects[k - 1].
    n_row, n_col = bkg_mask.shape
    cen_label = segmap[int(n_row / 2.), int(n_col / 2.)]
    if cen_label > 0:
        cen_obj = objects[cen_label - 1]
    else:
        # The central pixel is not on any segment. Fall back to the detection
        # closest to the image center. Label 0 would otherwise silently index
        # objects[-1], an arbitrary object.
        dist = np.hypot(np.asarray(objects['x']) - n_col / 2.,
                        np.asarray(objects['y']) - n_row / 2.)
        cen_obj = objects[int(np.argmin(dist))]
    fraction_radius = sep.flux_radius(img, cen_obj['x'], cen_obj['y'], 10 * cen_obj['a'], 0.5)[0]

    ba = np.divide(cen_obj['b'], cen_obj['a'])

    if fraction_radius < int(n_row / 8.):
        sep.mask_ellipse(bkg_mask, cen_obj['x'], cen_obj['y'], fraction_radius, fraction_radius * ba,
                         cen_obj['theta'], r=central_mask_radius)
    elif fraction_radius < int(n_row / 4.):
        sep.mask_ellipse(bkg_mask, cen_obj['x'], cen_obj['y'], fraction_radius, fraction_radius * ba,
                         cen_obj['theta'], r=1.2)

    # Estimate sky from histogram of binned image
    data = copy.deepcopy(img)
    data[bkg_mask] = np.nan
    if f_factor is None:
        f_factor = round(6 / pixel_scale)
    rebin = block_reduce(data, f_factor)
    sample = rebin.flatten()
    if show_fig:
        display_single(rebin)
        plt.savefig('./{}-bkg.png'.format(np.random.randint(1000)), dpi=100, bbox_inches='tight')

    # compressed() keeps only unmasked entries, even when nothing was clipped
    # and the mask is not a full boolean array.
    sample = sigma_clip(sample).compressed()

    kde = stats.gaussian_kde(sample)
    print(f_factor)
    mean = np.nanmean(sample) / f_factor**2
    median = np.nanmedian(sample) / f_factor**2
    std = np.nanstd(sample, ddof=1) / f_factor / np.sqrt(len(sample))

    xlim = np.std(sample, ddof=1) * 7
    x = np.linspace(-xlim + np.median(sample), xlim + np.median(sample), 100)
    # Mode of the KDE, rescaled to per-pixel units
    offset = x[np.argmax(kde.evaluate(x))] / f_factor**2

    print('mean', mean)
    print('median', median)
    print('std', std)

    # sep expects integer box/filter sizes
    half_f = int(f_factor / 2)
    bkg_global = sep.Background(img,
                                mask=bkg_mask, maskthresh=0,
                                bw=f_factor, bh=f_factor,
                                fw=half_f, fh=half_f)
    print("#SEP sky: Mean Sky / RMS Sky = %10.5f / %10.5f" % (bkg_global.globalback, bkg_global.globalrms))

    if show_hist:
        fig, ax = plt.subplots(figsize=(8, 6))

        ax.plot(x, kde.evaluate(x), linestyle='dashed', c='black', lw=2,
                label='KDE')
        # `normed` was removed from matplotlib; `density` is its replacement.
        ax.hist(sample, bins=x, density=True)
        ax.legend(loc='best', frameon=False, fontsize=20)

        ax.set_xlabel('Pixel Value', fontsize=20)
        ax.set_ylabel('Normed Number', fontsize=20)
        ax.tick_params(labelsize=20)
        ylim = ax.get_ylim()
        ax.text(-0.1 * f_factor + np.median(sample), 0.9 * (ylim[1] - ylim[0]) + ylim[0],
                r'$\mathrm{offset}='+str(round(offset, 6))+'$', fontsize=20)
        ax.text(-0.1 * f_factor + np.median(sample), 0.8 * (ylim[1] - ylim[0]) + ylim[0],
                r'$\mathrm{median}='+str(round(median, 6))+'$', fontsize=20)
        ax.text(-0.1 * f_factor + np.median(sample), 0.7 * (ylim[1] - ylim[0]) + ylim[0],
                r'$\mathrm{std}='+str(round(std, 6))+'$', fontsize=20)
        plt.vlines(np.median(sample), 0, ylim[1], linestyle='--')

    return median, std, sample
Ejemplo n.º 6
0
def h5_gen_mock_image_double(h5_path, pixel_scale, band,
    i_gal_flux, i_gal_rh, i_gal_q, i_sersic_index, i_gal_beta,
    i_psf_rh, groupname=None):
    '''
    Generate a mock image of a two-component Sersic galaxy and store it.

    The background is read from ``['Background'][band]['image']`` of the h5
    file. The PSF-convolved galaxy model is written to
    ``['ModelImage'][band][groupname]['modelimage']``, and model plus
    background to ``['MockImage'][band][groupname]['mockimage']``. Both
    images are then plotted.

    Parameters:
    -----------
    h5_path: string, the path of your h5 file (opened in 'r+' mode).
    pixel_scale: float, in the unit of arcsec/pixel.
    band: string, such as 'r-band'.
    i_gal_flux: sequence of two floats, galsim flux of each Sersic component.
    i_gal_rh: sequence of two floats, half-light radius of each component.
    i_gal_q: float, input b/a, applied to the summed profile.
    i_sersic_index: sequence of two floats, Sersic index of each component.
    i_gal_beta: float, input position angle (in degrees).
    i_psf_rh: float, used as the ``sigma`` of a Gaussian PSF.
    groupname: string, such as 'model-0'.
        Default: ``'n' + str(i_sersic_index)``.
    '''
    import h5py
    import galsim
    from .h5file import h5_rewrite_dataset

    # The context manager closes the file even if any step below fails.
    with h5py.File(h5_path, 'r+') as f:
        field = f['Background'][band]['image'][:]
        print('Size (in pixel):', [field.shape[1], field.shape[0]])
        print('Angular size (in arcsec):', [
            field.shape[1] * pixel_scale, field.shape[0] * pixel_scale
        ])
        print('The center of this image:', [field.shape[1] / 2, field.shape[0] / 2])

        # Define sersic galaxy: sum of two components
        gal1 = galsim.Sersic(i_sersic_index[0], half_light_radius=i_gal_rh[0], flux=i_gal_flux[0])
        gal2 = galsim.Sersic(i_sersic_index[1], half_light_radius=i_gal_rh[1], flux=i_gal_flux[1])
        gal = gal1 + gal2
        # Shear the galaxy by some value.
        # q, beta      Axis ratio and position angle: q = b/a, 0 < q < 1
        gal_shape = galsim.Shear(q=i_gal_q, beta=i_gal_beta * galsim.degrees)
        gal = gal.shear(gal_shape)
        # Define the PSF profile
        psf = galsim.Gaussian(sigma=i_psf_rh, flux=1.)
        # Convolve galaxy with PSF
        final = galsim.Convolve([gal, psf])
        # Draw the image with a particular pixel scale, matching the background size.
        image = final.drawImage(scale=pixel_scale, nx=field.shape[1], ny=field.shape[0])

        if groupname is None:
            groupname = 'n' + str(i_sersic_index)

        # require_group returns the existing group or creates it.
        g1 = f['ModelImage'][band].require_group(groupname)
        h5_rewrite_dataset(g1, 'modelimage', image.array)

        # Generate mock image
        mock_img = image.array + field

        g2 = f['MockImage'][band].require_group(groupname)
        h5_rewrite_dataset(g2, 'mockimage', mock_img)

    # Plot fake galaxy and the composite mock image
    fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(12, 6))
    display_single(image.array, ax=ax1, scale_bar_length=10)
    display_single(mock_img, scale_bar_length=10, ax=ax2)
    plt.show(block=False)
    plt.subplots_adjust(wspace=0.)
Ejemplo n.º 7
0
def tractor_iteration(obj_cat, w, img_data, invvar, psf_obj, pixel_scale, shape_method='manual', 
                      kfold=4, first_num=50, fig_name=None):
    '''
    Run tractor iteratively, adding catalogue objects to the model in batches.

    The first pass fits the first ``first_num`` objects of ``obj_cat``. Each of the
    following ``kfold - 1`` passes appends the next batch of objects to the sources
    found so far and re-optimizes the whole model. The last pass always takes every
    object that is still left, so objects are not dropped when
    ``len(obj_cat) - first_num`` is not a multiple of ``kfold - 1``.

    Parameters:
    -----------
    obj_cat: objects catalogue (anything supporting ``len()`` and slicing).
    w: wcs object.
    img_data: 2-D np.array, image.
    invvar: 2-D np.array, inverse variance matrix of the image.
    psf_obj: PSF object, defined by tractor.psf.PixelizedPSF() class.
    pixel_scale: float, pixel scale in unit arcsec/pixel.
    shape_method: if 'manual', then adopt manually measured shape. If 'decals', then adopt DECaLS shape from tractor files.
    kfold: int >= 1, number of iterations. With ``kfold=1`` all objects are fit in a single pass.
    first_num: how many objects will be fit in the first run (when ``kfold > 1``).
    fig_name: string, if not None, the final diagnostic figure is saved to this path.

    Returns:
    -----------
    sources: list, containing tractor model sources.
    trac_obj: optimized tractor object after many iterations.
    fig: matplotlib figure of the last diagnostic plot (raw image / model / residual).

    Raises:
    -------
    ValueError: if ``kfold`` is smaller than 1.
    '''
    # Validate before importing tractor so bad arguments fail fast and clearly
    # (kfold <= 0 used to leave ``sources``/``fig`` unbound, kfold == 1 divided by zero).
    if kfold < 1:
        raise ValueError('kfold must be a positive integer, got {}'.format(kfold))

    from tractor import NullWCS, NullPhotoCal, ConstantSky
    from tractor.psf import Image, Tractor

    n_obj = len(obj_cat)
    # Size of every batch after the first one. Integer division leaves a remainder,
    # which the final pass absorbs by slicing up to ``n_obj``.
    step = max(n_obj - first_num, 0) // (kfold - 1) if kfold > 1 else 0

    sources = None
    fig = None
    for i in range(kfold):
        start = 0 if i == 0 else first_num + step * (i - 1)
        stop = n_obj if i == kfold - 1 else first_num + step * i
        obj_small_cat = obj_cat[start:stop]
        # ``sources`` is None on the first pass; later passes extend the existing list.
        sources = add_tractor_sources(obj_small_cat, sources, w, shape_method=shape_method)

        tim = Image(data=img_data,
                    invvar=invvar,
                    psf=psf_obj,
                    wcs=NullWCS(pixscale=pixel_scale),
                    sky=ConstantSky(0.0),
                    photocal=NullPhotoCal()
                    )
        trac_obj = Tractor([tim], sources)

        # Optimization; the 'images' parameters are frozen, presumably so that only
        # the source parameters are fitted.
        trac_obj.freezeParam('images')
        trac_obj.optimize_loop()

        plt.rc('font', size=20)
        # Diagnostic plot on every odd pass and on the final pass.
        if i % 2 == 1 or i == (kfold - 1):
            fig, [ax1, ax2, ax3] = plt.subplots(1, 3, figsize=(18, 8))

            trac_mod_opt = trac_obj.getModelImage(0, minsb=0., srcs=sources[:])

            ax1 = display_single(img_data, ax=ax1, scale_bar=False)
            ax1.set_title('raw image')
            ax2 = display_single(trac_mod_opt, ax=ax2, scale_bar=False, contrast=0.02)
            ax2.set_title('tractor model')
            ax3 = display_single(abs(img_data - trac_mod_opt), ax=ax3, scale_bar=False, color_bar=True, contrast=0.05)
            ax3.set_title('residual')

            # NOTE: the value printed as "chi-square" is the RMS of the residual image.
            rms = np.sqrt(np.mean(np.square((img_data - trac_mod_opt).flatten())))
            if i == (kfold - 1):
                if fig_name is not None:
                    plt.savefig(fig_name, dpi=200, bbox_inches='tight')
                    plt.show()
                    print('The chi-square is', rms)
            else:
                plt.show()
                print('The chi-square is', rms / np.sum(img_data))

    return sources, trac_obj, fig
Ejemplo n.º 8
0
def _overplot_galfit_components(ax, x_comps, y_comps, re_comps, ba_comps, pa_comps):
    """Outline each Sersic component on ``ax`` as an unfilled ellipse.

    The ellipse height is ``2 * re`` and its width ``2 * re * (b/a)``, rotated by the
    component position angle. Edge colors are taken from a ``color_comps`` sequence
    expected at module scope (NOTE(review): not visible here -- confirm it exists).
    Drawing is best-effort: any failure prints a message instead of aborting the plot.
    """
    try:
        for ii, r0 in enumerate(re_comps):
            ellip_comp = Ellipse(xy=(x_comps[ii], y_comps[ii]),
                                 width=(r0 * ba_comps[ii] * 2.0),
                                 height=(r0 * 2.0),
                                 angle=pa_comps[ii])
            ax.add_artist(ellip_comp)
            ellip_comp.set_clip_box(ax.bbox)
            ellip_comp.set_alpha(1.0)
            ellip_comp.set_edgecolor(color_comps[ii])
            ellip_comp.set_facecolor('none')
            ellip_comp.set_linewidth(2.5)
    except Exception:
        print("XXX Can not highlight the components")


def show_galfit_model(galfit_fits,
                      galfit_out,
                      root=None,
                      verbose=True,
                      vertical=False,
                      zoom=True,
                      zoom_limit=6.0,
                      show_title=True,
                      show_chi2=True,
                      zoom_size=None,
                      overplot_components=True,
                      mask_residual=True,
                      mask_file=None,
                      show_bic=True,
                      xsize=15.0,
                      ysize=5.0,
                      cmap=None,
                      axes=None,
                      show_contour=True,
                      title=None,
                      pixel=0.168,
                      **kwargs):
    """Three columns plot of the Galfit models (data, model, residual).

    Parameters
    ----------
    galfit_fits : str
        GALFIT output image block. Extensions 1, 2 and 3 are read as the data,
        model and residual images. When ``axes`` is None the figure is saved to
        this path with ``.fits`` replaced by ``.png``.
    galfit_out : object
        Parsed GALFIT result. Must provide ``num_components``, ``component_<i>``
        attributes (``component_type``, ``xc``, ``yc``, ``re``, ``ar``, ``pa``),
        ``box_x0``/``box_x1``/``box_y0``/``box_y1``, ``input_mask`` and
        ``reduced_chisq``.
    root : str, optional
        Directory joined with ``galfit_out.input_mask`` when ``mask_file`` is None.
        If None, ``input_mask`` is used as given.
    verbose : bool
        Print the input FITS file and output figure names.
    vertical : bool
        Stack the three panels vertically. ``xsize``/``ysize`` are swapped so that the
        figure is tall rather than wide.
    zoom, zoom_limit : bool, float
        Crop to ``zoom_limit`` times the largest effective radius when that box fits
        inside the image.
    show_title, title : bool, str
        Draw ``title`` on the first panel.
    show_chi2, show_bic : bool
        Annotate the model panel with the reduced chi^2 and the BIC.
    zoom_size : int, optional
        Explicit crop size in pixels. Takes precedence over ``zoom``.
    overplot_components : bool
        Outline the Sersic components on the data and residual panels.
    mask_residual, mask_file : bool, str
        Set masked pixels (mask > 0) of the residual map to NaN.
    xsize, ysize : float
        Figure size in inches for the horizontal layout.
    cmap : matplotlib colormap, optional
        Defaults to ``viridis``.
    axes : tuple of three Axes, optional
        Pre-built axes to draw into. When given, no figure is created or saved.
    show_contour : bool
        Overlay contours of ``arcsinh(model)`` on the model panel.
    pixel : float
        Pixel scale forwarded to ``display_single`` as ``pixel_scale``.
    **kwargs
        Forwarded to ``display_single``.

    Returns
    -------
    The created figure when ``axes`` is None, otherwise ``(ax1, ax2, ax3)``.
    """
    # Read the images and copy them so they remain valid once the file is closed.
    with fits.open(galfit_fits) as galfit_results:
        img_ori = galfit_results[1].data.copy()
        img_mod = galfit_results[2].data.copy()
        img_res = galfit_results[3].data.copy()
    fig_name = galfit_fits.replace('.fits', '.png')

    # Colormap
    cmap = plt.get_cmap('viridis') if cmap is None else cmap

    if verbose:
        print(" ## %s ---> %s " % (galfit_fits, fig_name))

    # Set up the figure
    if axes is None:
        if vertical:
            # Swap before creating the figure so that the swap actually takes effect.
            xsize, ysize = ysize, xsize
        fig = plt.figure(figsize=(xsize, ysize))
        grid = gridspec.GridSpec(3, 1) if vertical else gridspec.GridSpec(1, 3)
        grid.update(wspace=0.0, hspace=0.0, left=0, right=1, top=1, bottom=0)
        ax1 = plt.subplot(grid[0])
        ax2 = plt.subplot(grid[1])
        ax3 = plt.subplot(grid[2])
    else:
        # Can provide a tuple of three pre-designed axes
        ax1, ax2, ax3 = axes

    # Geometry of each Sersic component
    x_comps, y_comps, re_comps, ba_comps, pa_comps = [], [], [], [], []
    for i in range(1, galfit_out.num_components + 1):
        info_comp = getattr(galfit_out, 'component_' + str(i))
        if info_comp.component_type == 'sersic':
            x_comps.append(info_comp.xc)
            y_comps.append(info_comp.yc)
            re_comps.append(info_comp.re)
            ba_comps.append(info_comp.ar)
            pa_comps.append(info_comp.pa)

    # numpy arrays are indexed [row, column] = [y, x]
    n_rows, n_cols = img_ori.shape
    row_cen, col_cen = n_rows / 2.0, n_cols / 2.0

    # Show mask on the residual map
    if mask_residual:
        # Allow user to provide an external mask
        if mask_file is None:
            mask_file = (galfit_out.input_mask if root is None
                         else os.path.join(root, galfit_out.input_mask))
        if os.path.isfile(mask_file) or os.path.islink(mask_file):
            img_msk = fits.getdata(mask_file, ext=0)
            # The GALFIT fitting box is 1-based and inclusive. x runs along columns and
            # y along rows, so the cutout must be taken as [y, x] to match img_res.
            img_msk = img_msk[int(galfit_out.box_y0) - 1:int(galfit_out.box_y1),
                              int(galfit_out.box_x0) - 1:int(galfit_out.box_x1)]
            img_res[img_msk > 0] = np.nan
        else:
            print("XXX Can not find the mask file : %s" % mask_file)

    # Half-size of the displayed region (None means show the full image).
    r_max = np.max(np.asarray(re_comps)) * zoom_limit if re_comps else np.inf
    if zoom_size is not None:
        r_zoom = zoom_size / 2.0
    elif zoom and row_cen >= r_max and col_cen >= r_max:
        r_zoom = r_max
        print(" ## Image has been truncated to highlight the galaxy !")
    else:
        r_zoom = None

    if r_zoom is None:
        row0, row1, col0, col1 = 0, n_rows, 0, n_cols
    else:
        # Clamp the start at 0: a negative start would index from the image end.
        row0, row1 = max(int(row_cen - r_zoom), 0), int(row_cen + r_zoom)
        col0, col1 = max(int(col_cen - r_zoom), 0), int(col_cen + r_zoom)

    img_ori = img_ori[row0:row1, col0:col1]
    img_mod = img_mod[row0:row1, col0:col1]
    img_res = img_res[row0:row1, col0:col1]

    # Component centres in the displayed cutout's pixel coordinates: shift by the
    # fitting-box origin, then by the crop offset along the matching axis.
    x_comps = np.asarray(x_comps, dtype=float) - float(galfit_out.box_x0) - col0
    y_comps = np.asarray(y_comps, dtype=float) - float(galfit_out.box_y0) - row0
    re_comps = np.asarray(re_comps)
    ba_comps = np.asarray(ba_comps)
    pa_comps = np.asarray(pa_comps)

    # Show the original image
    ax1.xaxis.set_major_formatter(NullFormatter())
    ax1.yaxis.set_major_formatter(NullFormatter())
    ax1 = display_single(img_ori,
                         ax=ax1,
                         cmap=cmap,
                         pixel_scale=pixel,
                         **kwargs)

    if overplot_components:
        _overplot_galfit_components(ax1, x_comps, y_comps, re_comps, ba_comps, pa_comps)

    # Show a title
    if show_title and title is not None:
        str_title = ax1.text(0.50,
                             0.90,
                             r'$\mathrm{%s}$' % title,
                             fontsize=25,
                             transform=ax1.transAxes,
                             horizontalalignment='center')
        str_title.set_bbox(
            dict(facecolor='white', alpha=0.6, edgecolor='white'))

    # Show the model
    ax2.xaxis.set_major_formatter(NullFormatter())
    ax2.yaxis.set_major_formatter(NullFormatter())
    ax2 = display_single(img_mod,
                         ax=ax2,
                         cmap=cmap,
                         pixel_scale=pixel,
                         **kwargs)

    # Show contours of the model; matplotlib's default X/Y grid follows the array
    # shape, so non-square cutouts work too.
    if show_contour:
        try:
            ax2.contour(np.arcsinh(img_mod),
                        colors='c',
                        linewidths=1.5)
        except Exception:
            print("XXX Can not generate the Contour !")

    # Show the reduced chisq
    if show_chi2:
        ax2.text(0.06,
                 0.92,
                 r'${\chi}^2/N_{DoF} : %s$' % galfit_out.reduced_chisq,
                 fontsize=14,
                 transform=ax2.transAxes)
    if show_bic:
        aic, bic, hq = galfit_naive_aic(galfit_out)
        ax2.text(0.06,
                 0.82,
                 r'$\mathrm{BIC} : %9.3f$' % bic,
                 fontsize=14,
                 transform=ax2.transAxes)

    # Show the residual image
    ax3.xaxis.set_major_formatter(NullFormatter())
    ax3.yaxis.set_major_formatter(NullFormatter())
    ax3 = display_single(img_res,
                         ax=ax3,
                         cmap=cmap,
                         pixel_scale=pixel,
                         **kwargs)

    # Components on the residual panel (previously added to ax1 by mistake).
    if overplot_components:
        _overplot_galfit_components(ax3, x_comps, y_comps, re_comps, ba_comps, pa_comps)

    # Save Figure
    if axes is None:
        fig.savefig(fig_name, dpi=80)
        return fig
    return ax1, ax2, ax3