# 예제 #1 (Example #1)
# 0
# 파일: image.py 프로젝트: rayadastidar/FRB  (File: image.py, Project: rayadastidar/FRB)
def main(pargs):
    """Render an image around an FRB position and write it to disk.

    Args:
        pargs (argparse.Namespace): Parsed command-line arguments.  The
            attributes read here are ``fits_file`` (path of the FITS image),
            ``frb_coord`` (coordinate string passed to
            ``coord_arg_to_coord``), ``vmnx`` (``None`` or a ``"min,max"``
            string of display limits), ``imsize`` and ``outfile``.
    """
    from matplotlib import pyplot as plt
    from astropy.io import fits

    from frb import frb
    from frb.figures import galaxies as ffgal
    from frb.figures import utils as ffutils

    from linetools.scripts.utils import coord_arg_to_coord

    # Load up.  The context manager guarantees the FITS handle is closed;
    # all use of ``hdu`` (including rendering in savefig) stays inside it.
    with fits.open(pargs.fits_file) as hdu:
        icoord = coord_arg_to_coord(pargs.frb_coord)

        # Parse the optional "vmin,vmax" display limits
        if pargs.vmnx is not None:
            tstr = pargs.vmnx.split(',')
            vmnx = (float(tstr[0]), float(tstr[1]))
        else:
            vmnx = (None, None)

        # Dummy FRB object used only to carry the coordinate and the
        # error ellipse drawn by sub_image
        FRB = frb.FRB('TMP', icoord, 0.)
        FRB.set_ee(1.0, 1.0, 0., 95.)

        fig = plt.figure(figsize=(7, 7))
        ffutils.set_mplrc()
        ffgal.sub_image(fig,
                        hdu,
                        FRB,
                        vmnx=vmnx,
                        cmap='gist_heat',
                        frb_clr='white',
                        imsize=pargs.imsize)

        # Layout and save
        plt.tight_layout(pad=0.2, h_pad=0.1, w_pad=0.1)
        plt.savefig(pargs.outfile, dpi=300)
        plt.close()

    print('Wrote {:s}'.format(pargs.outfile))
# 예제 #2 (Example #2)
# 0
def generate(image,
             wcs,
             title,
             flip_ra=False,
             flip_dec=False,
             log_stretch=False,
             cutout=None,
             primary_coord=None,
             secondary_coord=None,
             third_coord=None,
             slit=None,
             vmnx=None,
             extra_text=None,
             outfile=None):
    """
    Basic method to generate a Finder chart figure

    Args:
        image (np.ndarray):
          Image for the finder
        wcs (astropy.wcs.WCS):
          WCS solution
        title (str):
          Title; typically the name of the primary source
        flip_ra (bool, default False):
            Flip the RA (x-axis). Useful for southern hemisphere finders.
        flip_dec (bool, default False):
            Flip the Dec (y-axis). Useful for southern hemisphere finders.
        log_stretch (bool, optional):
            Use a log stretch for the image display
        cutout (tuple, optional):
            SkyCoord (center coordinate) and Quantity (image angular size)
            for a cutout from the input image.
        primary_coord (astropy.coordinates.SkyCoord, optional):
          If provided, place a mark in red at this coordinate
        secondary_coord (astropy.coordinates.SkyCoord, optional):
          If provided, place a mark in cyan at this coordinate
          Assume it is an offset star (i.e. calculate offsets)
        third_coord (astropy.coordinates.SkyCoord, optional):
          If provided, place a mark in yellow at this coordinate
        slit (tuple, optional):
          If provided, places a rectangular slit with specified
          coordinates, width, length, and position angle on image (from North to East)
          [SkyCoord('21h44m25.255s', '-40d54m00.1s', frame='icrs'), 1*u.arcsec, 10*u.arcsec, 20*u.deg]
        vmnx (tuple, optional):
          Used for scaling the image.  Otherwise, the image
          is analyzed for these values.
        extra_text (str, optional):
          Extra text to be added at the bottom of the Figure.
          e.g. `DSS r-filter`
        outfile (str, optional):
          Filename for the figure.  File type will be according
          to the extension

    Returns:
        matplotlib.pyplot.figure, matplotlib.pyplot.Axis

    Raises:
        IOError: if ``slit`` is given while the module-level ``flag_photu``
          is False (photutils unavailable).  This is checked before any
          figure is created.
    """
    # Fail fast: validate the slit request before touching pyplot, so the
    # error path does not leave a half-drawn figure registered with pyplot.
    if (slit is not None) and (flag_photu is False):
        raise IOError('Slit cannot be placed without photutils package')

    utils.set_mplrc()

    plt.clf()
    fig = plt.figure(figsize=(7, 8.5))

    # Cutout?
    if cutout is not None:
        cutout_img = Cutout2D(image, cutout[0], cutout[1], wcs=wcs)
        # Overwrite
        wcs = cutout_img.wcs
        image = cutout_img.data

    # Axis
    ax = fig.add_axes([0.10, 0.20, 0.75, 0.5], projection=wcs)

    # Show
    if log_stretch:
        norm = mplnorm.ImageNormalize(stretch=LogStretch())
    else:
        norm = None
    cimg = ax.imshow(image, cmap='Greys', norm=norm)

    # Optional axis inversions
    if flip_ra is True:
        ax.invert_xaxis()
    if flip_dec is True:
        ax.invert_yaxis()

    # N/E
    overlay = ax.get_coords_overlay('icrs')

    overlay['ra'].set_ticks(color='white')
    overlay['dec'].set_ticks(color='white')

    overlay['ra'].set_axislabel('Right Ascension')
    overlay['dec'].set_axislabel('Declination')

    overlay.grid(color='green', linestyle='solid', alpha=0.5)

    # Contrast
    if vmnx is None:
        # Also set clipping level and number of iterations here if necessary
        mean, median, stddev = sigma_clipped_stats(image)
        # sky level - 1 sigma and +2 sigma above sky level
        vmnx = (median - stddev, median + 2 * stddev)
        print("Using vmnx = {} based on the image stats".format(vmnx))
    cimg.set_clim(vmnx[0], vmnx[1])

    # Add Primary
    if primary_coord is not None:
        c = SphericalCircle((primary_coord.ra, primary_coord.dec),
                            2 * units.arcsec,
                            transform=ax.get_transform('icrs'),
                            edgecolor='red',
                            facecolor='none')
        ax.add_patch(c)
        # Text
        jname = ltu.name_from_coord(primary_coord)
        ax.text(0.5,
                1.34,
                jname,
                fontsize=28,
                horizontalalignment='center',
                transform=ax.transAxes)

    # Secondary
    if secondary_coord is not None:
        c_S1 = SphericalCircle((secondary_coord.ra, secondary_coord.dec),
                               2 * units.arcsec,
                               transform=ax.get_transform('icrs'),
                               edgecolor='cyan',
                               facecolor='none')
        ax.add_patch(c_S1)
        # Text
        jname = ltu.name_from_coord(secondary_coord)
        ax.text(0.5,
                1.24,
                jname,
                fontsize=22,
                color='blue',
                horizontalalignment='center',
                transform=ax.transAxes)
        # Print offsets
        if primary_coord is not None:
            sep = primary_coord.separation(secondary_coord).to('arcsec')
            PA = primary_coord.position_angle(secondary_coord)
            # Decompose the separation along Dec (cos PA) and RA (sin PA)
            dec_off = np.cos(PA) * sep  # arcsec
            ra_off = np.sin(PA) * sep  # arcsec (East is *higher* RA)
            # Offsets above are primary->secondary; negate for star->target
            ax.text(
                0.5,
                1.22,
                'Offset from Ref. Star (cyan) to Target (red):\nRA(to targ) = {:.2f}  DEC(to targ) = {:.2f}'
                .format(-1 * ra_off.to('arcsec'), -1 * dec_off.to('arcsec')),
                fontsize=15,
                horizontalalignment='center',
                transform=ax.transAxes,
                color='blue',
                va='top')
    # Add tertiary
    if third_coord is not None:
        c = SphericalCircle((third_coord.ra, third_coord.dec),
                            2 * units.arcsec,
                            transform=ax.get_transform('icrs'),
                            edgecolor='yellow',
                            facecolor='none')
        ax.add_patch(c)

    # Slit (photutils availability was validated at the top)
    if (slit is not None) and (flag_photu is True):
        # List of values - [coordinates, width, length, PA],
        # e.g. [SkyCoord('21h44m25.255s', '-40d54m00.1s', frame='icrs'), 1*u.arcsec, 10*u.arcsec, 20*u.deg]
        slit_coords, width, length, pa = slit

        pa_deg = pa.to('deg').value

        # For theta=0, width goes North-South, which is slit length
        aper = SkyRectangularAperture(
            positions=slit_coords, w=length, h=width, theta=pa)

        apermap = aper.to_pixel(wcs)

        apermap.plot(color='purple', lw=1)

        plt.text(0.5,
                 -0.1,
                 'Slit PA={} deg'.format(pa_deg),
                 color='purple',
                 fontsize=15,
                 ha='center',
                 va='top',
                 transform=ax.transAxes)

    # Title
    ax.text(0.5,
            1.44,
            title,
            fontsize=32,
            horizontalalignment='center',
            transform=ax.transAxes)

    # Extra text
    if extra_text is not None:
        ax.text(-0.1,
                -0.25,
                extra_text,
                fontsize=20,
                horizontalalignment='left',
                transform=ax.transAxes)

    if outfile is not None:
        plt.savefig(outfile)
        plt.close()
    else:
        plt.show()

    # Return
    return fig, ax
# 예제 #3 (Example #3)
# 0
# 파일: galaxies.py 프로젝트: JayChittidi/FRB  (File: galaxies.py, Project: JayChittidi/FRB)
def sub_sfms(ax_M, galaxies, clrs, markers):
    """
    Generate a SFR vs. M* plot on top of PRIMUS galaxies

    A 2D histogram of PRIMUS galaxies (0.2 < z < 0.4, ISGOOD == 1) is drawn
    in grey, a linear star-forming relation is overplotted as a dashed line,
    and each input galaxy with a stellar-mass estimate is added as an
    errorbar point.

    Args:
        ax_M (matplotlib.axis):
            Axis to draw on; every artist is added to this axis.
        galaxies (list):
            List of FRB.galaxies.frbgalaxy.FRBGalaxy objects.  Galaxies with
            neither 'Mstar' nor 'Mstar_spec' in ``derived`` are skipped;
            the remaining ones must provide 'SFR_nebular'.
        clrs (list):
            List of matplotlib colors, indexed in parallel with ``galaxies``
        markers (list):
            List of matplotlib marker types, indexed in parallel with
            ``galaxies``

    """
    utils.set_mplrc()

    # Load up
    primus_zcat = Table.read(
        os.path.join(primus_path, 'PRIMUS_2013_zcat_v1.fits.gz'))
    primus_mass = Table.read(
        os.path.join(primus_path, 'PRIMUS_2014_mass_v1.fits.gz'))

    # The redshift cut (zcat) is combined element-wise with the quality flag
    # (mass table).  NOTE(review): assumes the two catalogs are row-aligned.
    gdz = (primus_zcat['Z'] > 0.2) & (primus_zcat['Z'] < 0.4)
    good_mass = primus_mass['ISGOOD'] == 1

    # Mass/SFR
    gd_msfr = good_mass & gdz
    mass = primus_mass['MASS'][gd_msfr]
    sfr = primus_mass['SFR'][gd_msfr]

    # Background 2D histogram of the PRIMUS sample
    counts, xedges, yedges = np.histogram2d(mass, sfr, bins=(50, 50))
    ax_M.pcolormesh(xedges, yedges, counts.transpose(),
                    cmap=plt.get_cmap('Greys'))

    # Linear star-forming relation in log SFR vs. log M*.
    # NOTE(review): coefficients presumably from Moustakas et al. 2013 — confirm.
    def _sf_relation(logm):
        return -0.49 + 0.65 * (logm - 10) + 1.07 * (0.35 - 0.1)

    logm_star = np.linspace(8, 12)
    ax_M.plot(logm_star, _sf_relation(logm_star), "k--", lw=3)

    # Galaxies
    for kk, galaxy in enumerate(galaxies):
        derived = galaxy.derived
        # M*: photometric estimate preferred, else spectroscopic with a
        # fixed 0.3 dex uncertainty
        if 'Mstar' in derived:
            logM, sig_logM = utils.log_me(derived['Mstar'],
                                          derived['Mstar_err'])
        elif 'Mstar_spec' in derived:
            logM, sig_logM = np.log10(derived['Mstar_spec']), 0.3
        else:
            continue
        # SFR: fall back to a 30% error when none is tabulated
        if 'SFR_nebular_err' in derived:
            logS, sig_logS = utils.log_me(derived['SFR_nebular'],
                                          derived['SFR_nebular_err'])
        else:
            logS, sig_logS = utils.log_me(derived['SFR_nebular'],
                                          0.3 * derived['SFR_nebular'])
        # Plot
        ax_M.errorbar([logM], [logS],
                      xerr=sig_logM,
                      yerr=sig_logS,
                      color=clrs[kk],
                      marker=markers[kk],
                      markersize=12,
                      capsize=3,
                      label=galaxy.name)
        if sig_logS is None:
            # Down arrow.  Drawn on ax_M explicitly: plt.arrow would target
            # whatever axis pyplot considers current, not necessarily ax_M.
            ax_M.arrow(logM,
                       logS,
                       0.,
                       -0.2,
                       fc=clrs[kk],
                       ec=clrs[kk],
                       head_width=0.02 * 4,
                       head_length=0.05 * 2)

    ax_M.annotate(r"\textbf{Star forming}", (8.5, 0.8), fontsize=13.)
    ax_M.annotate(r"\textbf{Quiescent}", (11, -0.9), fontsize=13.)
    # Raw strings: "\l", "\," and "\o" are invalid escape sequences otherwise
    ax_M.set_xlabel(r"$\log \, (M_*/M_\odot)$")
    ax_M.set_ylabel(r"$\log \, SFR (M_\odot$/yr)")
    ax_M.legend(loc='lower left')
    ax_M.set_xlim(7.5, 11.8)
    ax_M.set_ylim(-2.5, 1.2)
# 예제 #4 (Example #4)
# 0
def generate(image,
             wcs,
             title,
             log_stretch=False,
             cutout=None,
             primary_coord=None,
             secondary_coord=None,
             third_coord=None,
             vmnx=None,
             outfile=None,
             flip_ra=True,
             flip_dec=False):
    """
    Basic method to generate a Finder chart figure

    Args:
        image (np.ndarray):
          Image for the finder
        wcs (astropy.wcs.WCS):
          WCS solution
        title (str):
          Title; typically the name of the primary source
        log_stretch (bool, optional):
            Use a log stretch for the image display
        cutout (tuple, optional):
            SkyCoord (center coordinate) and Quantity (image angular size)
            for a cutout from the input image.
        primary_coord (astropy.coordinates.SkyCoord, optional):
          If provided, place a mark in red at this coordinate
        secondary_coord (astropy.coordinates.SkyCoord, optional):
          If provided, place a mark in cyan at this coordinate
          Assume it is an offset star (i.e. calculate offsets)
        third_coord (astropy.coordinates.SkyCoord, optional):
          If provided, place a mark in yellow at this coordinate
        vmnx (tuple, optional):
          Used for scaling the image.  Otherwise, the image
          is analyzed for these values.
        outfile (str, optional):
          Filename for the figure.  File type will be according
          to the extension
        flip_ra (bool, default True):
          Invert the x-axis.  Defaults to True, preserving the previous
          unconditional behavior.
        flip_dec (bool, default False):
          Invert the y-axis.

    Returns:
        matplotlib.pyplot.figure, matplotlib.pyplot.Axis

    Notes:
        Slit overlays are not supported by this version of the finder.
    """

    utils.set_mplrc()

    plt.clf()
    fig = plt.figure(dpi=600)
    fig.set_size_inches(7.5, 10.5)

    # Cutout?
    if cutout is not None:
        cutout_img = Cutout2D(image, cutout[0], cutout[1], wcs=wcs)
        # Overwrite
        wcs = cutout_img.wcs
        image = cutout_img.data

    # Axis
    ax = fig.add_axes([0.10, 0.20, 0.80, 0.5], projection=wcs)

    # Show
    if log_stretch:
        norm = mplnorm.ImageNormalize(stretch=LogStretch())
    else:
        norm = None
    cimg = ax.imshow(image, cmap='Greys', norm=norm)

    # Optional axis inversions (x inverted by default)
    if flip_ra:
        ax.invert_xaxis()
    if flip_dec:
        ax.invert_yaxis()

    # N/E
    overlay = ax.get_coords_overlay('icrs')

    overlay['ra'].set_ticks(color='white')
    overlay['dec'].set_ticks(color='white')

    overlay['ra'].set_axislabel('Right Ascension')
    overlay['dec'].set_axislabel('Declination')

    overlay.grid(color='green', linestyle='solid', alpha=0.5)

    # Contrast
    if vmnx is None:
        # Also set clipping level and number of iterations here if necessary
        mean, median, stddev = sigma_clipped_stats(image)
        # sky level - 1 sigma and +2 sigma above sky level
        vmnx = (median - stddev, median + 2 * stddev)
        print("Using vmnx = {} based on the image stats".format(vmnx))
    cimg.set_clim(vmnx[0], vmnx[1])

    # Add Primary
    if primary_coord is not None:
        c = SphericalCircle((primary_coord.ra, primary_coord.dec),
                            2 * units.arcsec,
                            transform=ax.get_transform('icrs'),
                            edgecolor='red',
                            facecolor='none')
        ax.add_patch(c)
        # Text
        jname = ltu.name_from_coord(primary_coord)
        ax.text(0.5,
                1.34,
                jname,
                fontsize=28,
                horizontalalignment='center',
                transform=ax.transAxes)

    # Secondary
    if secondary_coord is not None:
        c_S1 = SphericalCircle((secondary_coord.ra, secondary_coord.dec),
                               2 * units.arcsec,
                               transform=ax.get_transform('icrs'),
                               edgecolor='cyan',
                               facecolor='none')
        ax.add_patch(c_S1)
        # Text
        jname = ltu.name_from_coord(secondary_coord)
        ax.text(0.5,
                1.24,
                jname,
                fontsize=22,
                color='blue',
                horizontalalignment='center',
                transform=ax.transAxes)
        # Print offsets
        if primary_coord is not None:
            sep = primary_coord.separation(secondary_coord).to('arcsec')
            PA = primary_coord.position_angle(secondary_coord)
            # Decompose the separation along Dec (cos PA) and RA (sin PA)
            dec_off = np.cos(PA) * sep  # arcsec
            ra_off = np.sin(PA) * sep  # arcsec (East is *higher* RA)
            # Offsets above are primary->secondary; negate for star->target
            ax.text(0.5,
                    1.14,
                    'RA(to targ) = {:.2f}  DEC(to targ) = {:.2f}'.format(
                        -1 * ra_off.to('arcsec'), -1 * dec_off.to('arcsec')),
                    fontsize=18,
                    horizontalalignment='center',
                    transform=ax.transAxes)
    # Add tertiary
    if third_coord is not None:
        c = SphericalCircle((third_coord.ra, third_coord.dec),
                            2 * units.arcsec,
                            transform=ax.get_transform('icrs'),
                            edgecolor='yellow',
                            facecolor='none')
        ax.add_patch(c)

    # Title
    ax.text(0.5,
            1.44,
            title,
            fontsize=32,
            horizontalalignment='center',
            transform=ax.transAxes)

    if outfile is not None:
        plt.savefig(outfile)
        plt.close()
    else:
        plt.show()

    # Return
    return fig, ax
# 예제 #5 (Example #5)
# 0
def fig_cosmic(frbs, clrs=None, outfile=None, multi_model=False, no_curves=False,
               widen=False, show_nuisance=False, ax=None,
               show_sigmaDM=False, cl=(16,84), beta=3., gold_only=True, gold_frbs=None):
    """Plot DM_cosmic vs. redshift for a set of FRBs.

    Args:
        frbs (list):
            list of FRB objects
        clrs (list, optional):
            Colors for the gold-standard FRBs, indexed in parallel with ``frbs``
        outfile (str, optional):
            If given, the figure is written here and closed
        multi_model (deprecated):
            If True, also draw curves for 1.2 x Omega_b and H0/1.2
        no_curves (bool, optional):
            If True, just show the data
        widen (bool, optional):
            If True, make the plot wide
        show_nuisance (bool, optional):
            if True, add a label giving the best-fit (minimum chi^2)
            nuisance value DM_MW,halo + DM_host
        show_sigmaDM (bool, optional):
            If True, show a model estimate of the scatter in the DM relation
        cl (tuple, optional):
            Confidence limits (percent) for the scatter
        beta (float, optional):
            Parameter to the DM scatter estimation; used both for the C0
            spline and the PDF so the two stay consistent
        gold_only (bool, optional):
            If True, limit to the gold standard sample
        gold_frbs (list, optional):
            List of gold standard FRBs
        ax (matplotlib.Axis, optional):
            Use this axis instead of creating one

    Returns:
        matplotlib.Axis or None: the axis when ``outfile`` is None,
        otherwise None (the figure is saved and closed).

    """
    # Init
    if gold_frbs is None:
        gold_frbs = cosmic.gold_frbs

    # Plotting
    ff_utils.set_mplrc()

    bias_clr = 'darkgray'

    # Start the plot
    if ax is None:
        if widen:
            plt.figure(figsize=(12, 8))
        else:
            plt.figure(figsize=(8, 8))
        plt.clf()
        ax = plt.gca()

    # DM_cosmic from cosmology
    zmax = 0.75
    DM_cosmic, zeval = frb_igm.average_DM(zmax, cumul=True)
    DMc_spl = IUS(zeval, DM_cosmic)
    if not no_curves:
        ax.plot(zeval, DM_cosmic, 'k-', label='Planck15')

    if multi_model:
        # Change Omega_b
        cosmo_highOb = FlatLambdaCDM(Ob0=Planck15.Ob0*1.2, Om0=Planck15.Om0, H0=Planck15.H0)
        DM_cosmic_high, zeval_high = frb_igm.average_DM(zmax, cumul=True, cosmo=cosmo_highOb)
        ax.plot(zeval_high, DM_cosmic_high, '--', color='gray', label=r'DM$_{\rm cosmic} (z) \;\; [1.2 \times \Omega_b]$')
        # Change H0
        cosmo_lowH0 = FlatLambdaCDM(Ob0=Planck15.Ob0, Om0=Planck15.Om0, H0=Planck15.H0/1.2)
        DM_cosmic_lowH0, zeval_lowH0 = frb_igm.average_DM(zmax, cumul=True, cosmo=cosmo_lowH0)
        ax.plot(zeval_lowH0, DM_cosmic_lowH0, ':', color='gray', label=r'DM$_{\rm cosmic} (z) \;\; [H_0/1.2]$')

    if show_sigmaDM:
        # Build the C0 spline with the same beta that the PDF below uses;
        # it was previously hard-coded to 3, which silently mismatched any
        # non-default ``beta``.
        f_C0 = cosmic.grab_C0_spline(beta=beta)
        # Scatter model sigma_DM = F * z^-1/2, sampled every nstep points
        F = 0.2
        nstep = 50
        sigma_DM = F * zeval**(-0.5)
        sub_sigma_DM = sigma_DM[::nstep]
        sub_z = zeval[::nstep]
        sub_DM = DM_cosmic.value[::nstep]
        # Loop me
        sigmas, C0s, sigma_lo, sigma_hi = [], [], [], []
        for kk, isigma in enumerate(sub_sigma_DM):
            sigmas.append(isigma)
            C0s.append(float(f_C0(isigma)))
            # PDF and its normalized CDF
            PDF = cosmic.DMcosmic_PDF(cosmic.Delta_values, C0s[-1], sigma=sigmas[-1], beta=beta)
            cumsum = np.cumsum(PDF) / np.sum(PDF)
            # Read the requested percentiles off the CDF, in DM units
            DM = cosmic.Delta_values * sub_DM[kk]
            sigma_lo.append(DM[np.argmin(np.abs(cumsum-cl[0]/100))])
            sigma_hi.append(DM[np.argmin(np.abs(cumsum-cl[1]/100))])
        # Plot
        ax.fill_between(sub_z, sigma_lo, sigma_hi,
                        color='gray', alpha=0.3)

    # Do each FRB: observed DM minus the ISM contribution
    DM_subs = np.array([(ifrb.DM - ifrb.DMISM).value for ifrb in frbs])

    # chi2 over a grid of constant MW-halo + host nuisance values
    DMs_MW_host = np.linspace(30., 100., 100)
    zs = np.array([ifrb.z for ifrb in frbs])
    DM_theory = DMc_spl(zs)

    chi2 = np.zeros_like(DMs_MW_host)
    for kk, trial_DM in enumerate(DMs_MW_host):
        chi2[kk] = np.sum(((DM_subs-trial_DM)-DM_theory)**2)

    imin = np.argmin(chi2)
    DM_MW_host_chisq = DMs_MW_host[imin]
    # Report the best-fit value (previously printed the last grid value)
    print("DM_nuisance = {}".format(DM_MW_host_chisq))

    # MW + Host term
    def DM_MW_host(z, min_chisq=False):
        if min_chisq:
            return DM_MW_host_chisq
        else:
            return 50. + 50./(1+z)

    # Gold FRBs
    for kk, ifrb in enumerate(frbs):
        if ifrb.frb_name not in gold_frbs:
            continue
        if clrs is not None:
            clr = clrs[kk]
        else:
            clr = None
        ax.scatter([ifrb.z], [DM_subs[kk]-DM_MW_host(ifrb.z)],
                   label=ifrb.frb_name, marker='s', s=90, color=clr)

    # ################################
    # Other FRBs (a single "Others" legend entry for the whole group)
    s_other = 90

    if not gold_only:
        labeled = False
        for kk, ifrb in enumerate(frbs):
            if ifrb.frb_name in gold_frbs:
                continue
            if not labeled:
                lbl = "Others"
                labeled = True
            else:
                lbl = None
            ax.scatter([ifrb.z], [ifrb.DM.value -
                                  ifrb.DMISM.value - DM_MW_host(ifrb.z)],
                       label=lbl, marker='o', s=s_other, color=bias_clr)

    ax.legend(loc='upper left', scatterpoints=1, borderpad=0.2,
              handletextpad=0.3, fontsize=19)
    ax.set_xlim(0, 0.7)
    ax.set_ylim(0, 1000.)
    ax.set_xlabel(r'$z_{\rm FRB}$', fontname='DejaVu Sans')
    ax.set_ylabel(r'$\rm DM_{cosmic} \; (pc \, cm^{-3})$', fontname='DejaVu Sans')

    # Best-fit nuisance label.  Uses the chi^2 minimum: the previous code
    # passed the nested function DM_MW_host to int(), which raised TypeError.
    if show_nuisance:
        ax.text(0.05, 0.60, r'$\rm DM_{MW,halo} + DM_{host} = $'+' {:02d} pc '.format(int(DM_MW_host_chisq))+r'cm$^{-3}$',
                transform=ax.transAxes, fontsize=23, ha='left', color='black')

    ff_utils.set_fontsize(ax, 23.)

    # Layout and save
    if outfile is not None:
        plt.tight_layout(pad=0.2,h_pad=0.1,w_pad=0.1)
        plt.savefig(outfile, dpi=400)
        print('Wrote {:s}'.format(outfile))
        plt.close()
    else:
        return ax