Ejemplo n.º 1
0
    def test_R_approximation(self):
        """check that the R approximation works for low frequencies"""
        frequencies = np.logspace(-6, -2, 10000) * u.Hz

        # approximate_R=False uses the full response function, approximate_R=True the approximation
        approx = psd.power_spectral_density(frequencies, approximate_R=True)
        exact = psd.power_spectral_density(frequencies, approximate_R=False)

        # np.allclose is asymmetric (tolerance is relative to the second argument), so the exact
        # curve is passed second as the reference
        self.assertTrue(np.allclose(approx, exact))
Ejemplo n.º 2
0
    def test_custom_confusion_noise(self):
        """ check that using custom confusion noise works """
        frequencies = np.logspace(-6, 0, 100) * u.Hz

        def zero_confusion(f, t):
            # a custom confusion noise that contributes nothing
            return 0 * u.Hz**(-1)

        # first the default instrument, then TianQin
        for extra_args in ({}, {"instrument": "TianQin"}):
            without_noise = psd.power_spectral_density(frequencies, confusion_noise=None, **extra_args)
            with_zero_noise = psd.power_spectral_density(frequencies, confusion_noise=zero_confusion,
                                                         **extra_args)
            self.assertTrue(np.allclose(without_noise, with_zero_noise))
Ejemplo n.º 3
0
    def test_mission_length_effect(self):
        """check that increasing the mission length isn't changing
        anything far from confusion noise"""
        frequencies = np.logspace(-6, 0, 100) * u.Hz

        # same curve evaluated for a range of mission lengths (in years)
        mission_years = (0.5, 1.0, 2.0, 4.5, 10.0)
        noises = [psd.power_spectral_density(frequencies, t_obs=years * u.yr)
                  for years in mission_years]
        reference = noises[mission_years.index(4.5)]

        # no curve may dip below the 4.5 yr reference (up to a tiny absolute tolerance)
        for noise in noises:
            not_below = np.logical_or(noise > reference,
                                      np.isclose(noise, reference, atol=1e-39))
            self.assertTrue(not_below.all())
Ejemplo n.º 4
0
    def test_alternate_instruments(self):
        """check that changing instruments doesn't break things

        A custom instrument built from the TianQin PSD with doubled arm length must never be
        noisier than TianQin itself, and an unknown instrument name must raise a ValueError.
        """
        frequencies = np.logspace(-6, 0, 100) * u.Hz

        tq = psd.power_spectral_density(frequencies, instrument="TianQin")

        def custom_instrument(f, t_obs, L, approximate_R, confusion_noise):
            # TianQin noise model, but with twice the supplied arm length
            return psd.tianqin_psd(f, L * 2, t_obs, approximate_R, confusion_noise)

        custom = psd.power_spectral_density(frequencies, instrument="custom",
                                            custom_psd=custom_instrument,
                                            L=np.sqrt(3) * 1e5 * u.km,
                                            confusion_noise="huang20",
                                            t_obs=5 * u.yr)

        self.assertTrue(np.all(custom <= tq))

        # an unrecognised instrument name must be rejected
        with self.assertRaises(ValueError):
            psd.power_spectral_density(frequencies, instrument="nonsense")
Ejemplo n.º 5
0
def snr_ecc_stationary(m_c,
                       f_orb,
                       ecc,
                       dist,
                       t_obs,
                       harmonics_required,
                       position=None,
                       polarisation=None,
                       inclination=None,
                       interpolated_g=None,
                       interpolated_sc=None,
                       ret_max_snr_harmonic=False,
                       ret_snr2_by_harmonic=False,
                       instrument="LISA",
                       custom_psd=None):
    """Computes SNR for eccentric and stationary sources

    The SNR^2 of every harmonic n = 1..``harmonics_required`` is the ratio of the source power
    (h_0_n^2 * t_obs) to the noise PSD at the harmonic frequency n * f_orb. The total SNR is the
    square root of the sum over harmonics.

    Parameters
    ----------
    m_c : `float/array`
        Chirp mass

    f_orb : `float/array`
        Orbital frequency

    ecc : `float/array`
        Eccentricity

    dist : `float/array`
        Distance to the source

    t_obs : `float`
        Total duration of the observation

    harmonics_required : `integer`
        Maximum integer harmonic to compute

    position : `SkyCoord/array`, optional
        Sky position of source. Must be specified using Astropy's :class:`astropy.coordinates.SkyCoord` class.

    polarisation : `float/array`, optional
        GW polarisation angle of the source. Must have astropy angular units.

    inclination : `float/array`, optional
        Inclination of the source. Must have astropy angular units.

    interpolated_g : `function`
        A function returned by :class:`scipy.interpolate.interp2d` that computes g(n,e) from Peters (1964).
        The code assumes that the function returns the output sorted as with the interp2d returned functions
        (and thus unsorts). Default is None and uses exact g(n,e) in this case.

    interpolated_sc : `function`
        A function returned by :class:`scipy.interpolate.interp1d` that computes the LISA sensitivity curve.
        Default is None and uses exact values. Note: take care to ensure that your interpolated function has
        the same LISA observation time as ``t_obs`` and uses the same instrument.

    ret_max_snr_harmonic : `boolean`
        Whether to return (in addition to the snr), the harmonic with the maximum SNR

    ret_snr2_by_harmonic : `boolean`
        Whether to return the SNR^2 in each individual harmonic rather than the total.
        The total can be retrieved by summing and then taking the square root.

    instrument : `{{ 'LISA', 'TianQin', 'custom' }}`
        Instrument to observe with. If 'custom' then ``custom_psd`` must be supplied.

    custom_psd : `function`
        Custom function for computing the PSD. Must take the same arguments as :meth:`legwork.psd.lisa_psd`
        even if it ignores some.

    Returns
    -------
    snr : `float/array`
        SNR for each binary. If ``ret_snr2_by_harmonic=True``, SNR^2 per binary and per harmonic
        (shape ``(len(m_c), harmonics_required)``) is returned instead.

    max_snr_harmonic : `int/array`
        harmonic with maximum SNR for each binary (only returned if ``ret_max_snr_harmonic=True``)
    """
    harmonics = np.arange(1, harmonics_required + 1).astype(int)

    # squared strain amplitude of each harmonic, one row per binary (single timestep)
    amplitude_sq = strain.h_0_n(m_c=m_c,
                                f_orb=f_orb,
                                ecc=ecc,
                                n=harmonics,
                                dist=dist,
                                position=position,
                                polarisation=polarisation,
                                inclination=inclination,
                                interpolated_g=interpolated_g)**2
    signal_power = amplitude_sq.reshape(len(m_c), harmonics_required) * t_obs

    # noise PSD at each harmonic frequency n * f_orb
    harmonic_freqs = harmonics[np.newaxis, :] * f_orb[:, np.newaxis]
    if interpolated_sc is None:
        noise_power = psd.power_spectral_density(f=harmonic_freqs,
                                                 t_obs=t_obs,
                                                 instrument=instrument,
                                                 custom_psd=custom_psd)
    else:
        noise_power = interpolated_sc(harmonic_freqs.flatten()).reshape(harmonic_freqs.shape)

    snr2_per_harmonic = (signal_power / noise_power).decompose()

    if ret_snr2_by_harmonic:
        return snr2_per_harmonic.value

    total_snr = (np.sum(snr2_per_harmonic, axis=1))**0.5
    snr_values = total_snr.decompose().value
    if not ret_max_snr_harmonic:
        return snr_values

    # harmonics are 1-indexed, column indices are 0-indexed
    return snr_values, np.argmax(snr2_per_harmonic, axis=1) + 1
Ejemplo n.º 6
0
def snr_ecc_evolving(m_1,
                     m_2,
                     f_orb_i,
                     dist,
                     ecc,
                     harmonics_required,
                     t_obs,
                     n_step,
                     position=None,
                     polarisation=None,
                     inclination=None,
                     t_merge=None,
                     interpolated_g=None,
                     interpolated_sc=None,
                     n_proc=1,
                     ret_max_snr_harmonic=False,
                     ret_snr2_by_harmonic=False,
                     instrument="LISA",
                     custom_psd=None):
    """Computes SNR for eccentric and evolving sources.

    The orbital frequency and eccentricity of each binary are evolved over the observation
    (or until shortly before merger, whichever is first). For each harmonic, the ratio of the
    characteristic strain squared to the characteristic noise squared is then integrated over
    the swept harmonic frequency. The results are summed over harmonics.

    Note that this function will not work for exactly circular (ecc = 0.0)
    binaries.

    Parameters
    ----------
    m_1 : `float/array`
        Primary mass

    m_2 : `float/array`
        Secondary mass

    f_orb_i : `float/array`
        Initial orbital frequency

    dist : `float/array`
        Distance to the source

    ecc : `float/array`
        Eccentricity

    harmonics_required : `int`
        Maximum integer harmonic to compute

    t_obs : `float`
        Total duration of the observation

    position : `SkyCoord/array`, optional
        Sky position of source. Must be specified using Astropy's :class:`astropy.coordinates.SkyCoord` class.

    polarisation : `float/array`, optional
        GW polarisation angle of the source. Must have astropy angular units.

    inclination : `float/array`, optional
        Inclination of the source. Must have astropy angular units.

    n_step : `int`
        Number of time steps during observation duration

    t_merge : `float/array`
        Time until merger. Computed with :meth:`evol.get_t_merge_ecc` when None.

    interpolated_g : `function`
        A function returned by :class:`scipy.interpolate.interp2d` that computes g(n,e) from Peters (1964).
        The code assumes that the function returns the output sorted as with the interp2d returned functions
        (and thus unsorts). Default is None and uses exact g(n,e) in this case.

    interpolated_sc : `function`
        A function returned by :class:`scipy.interpolate.interp1d` that computes the LISA sensitivity curve.
        Default is None and uses exact values. Note: take care to ensure that your interpolated function has
        the same LISA observation time as ``t_obs`` and uses the same instrument.

    n_proc : `int`
        Number of processors to split eccentricity evolution over, where
        the default is n_proc=1

    ret_max_snr_harmonic : `boolean`
        Whether to return (in addition to the snr), the harmonic with the maximum SNR

    ret_snr2_by_harmonic : `boolean`
        Whether to return the SNR^2 in each individual harmonic rather than the total.
        The total can be retrieved by summing and then taking the square root.

    instrument : `{{ 'LISA', 'TianQin', 'custom' }}`
        Instrument to observe with. If 'custom' then ``custom_psd`` must be supplied.

    custom_psd : `function`
        Custom function for computing the PSD. Must take the same arguments as :meth:`legwork.psd.lisa_psd`
        even if it ignores some.

    Returns
    -------
    snr : `float/array`
        SNR for each binary. If ``ret_snr2_by_harmonic=True``, the SNR^2 of each binary in each
        harmonic is returned instead (one row per binary, one column per harmonic).

    max_snr_harmonic : `int/array`
        harmonic with maximum SNR for each binary (only returned if
        ``ret_max_snr_harmonic=True``)
    """
    m_c = utils.chirp_mass(m_1=m_1, m_2=m_2)

    # calculate minimum of observation time and merger time
    if t_merge is None:
        t_merge = evol.get_t_merge_ecc(m_1=m_1,
                                       m_2=m_2,
                                       f_orb_i=f_orb_i,
                                       ecc_i=ecc)

    # evolution stops this long before merger; it is also passed on to evol.evol_ecc
    # (NOTE(review): presumably to stay clear of the divergence at merger -- confirm)
    t_before = 0.1 * u.yr

    t_evol = np.minimum(t_merge - t_before, t_obs).to(u.s)
    # get eccentricity and f_orb evolutions
    e_evol, f_orb_evol = evol.evol_ecc(ecc_i=ecc,
                                       t_evol=t_evol,
                                       n_step=n_step,
                                       m_1=m_1,
                                       m_2=m_2,
                                       f_orb_i=f_orb_i,
                                       n_proc=n_proc,
                                       t_before=t_before,
                                       t_merge=t_merge)

    # Orbital frequencies equal to 1e2 Hz are treated as placeholders.
    # NOTE(review): presumably evol.evol_ecc writes them for steps after merger -- confirm.
    # The per-source maximum excludes points that are both e == 0 and 1e2 Hz. The replacement
    # below, however, overwrites every 1e2 Hz entry with that maximum. If the placeholders trail
    # the genuine steps, they then add zero width to the trapezoid integral.
    maxes = np.where(np.logical_and(e_evol == 0.0, f_orb_evol == 1e2 * u.Hz),
                     -1 * u.Hz, f_orb_evol).max(axis=1)
    for source in range(len(f_orb_evol)):
        f_orb_evol[source][f_orb_evol[source] == 1e2 * u.Hz] = maxes[source]

    # create harmonics list and multiply for nth frequency evolution
    # (assumes f_orb_evol is (n_sources, n_step), so f_n_evol is (n_sources, n_step, n_harmonics))
    harms = np.arange(1, harmonics_required + 1).astype(int)
    f_n_evol = harms[np.newaxis, np.newaxis, :] * f_orb_evol[..., np.newaxis]

    # calculate the characteristic strain
    h_c_n_2 = strain.h_c_n(m_c=m_c,
                           f_orb=f_orb_evol,
                           ecc=e_evol,
                           n=harms,
                           dist=dist,
                           position=position,
                           polarisation=polarisation,
                           inclination=inclination,
                           interpolated_g=interpolated_g)**2

    # calculate the characteristic noise power
    # (the PSD is evaluated on a flat frequency array and reshaped back afterwards)
    if interpolated_sc is not None:
        h_f_lisa = interpolated_sc(f_n_evol.flatten())
    else:
        h_f_lisa = psd.power_spectral_density(f=f_n_evol.flatten(),
                                              t_obs=t_obs,
                                              instrument=instrument,
                                              custom_psd=custom_psd)
    h_f_lisa = h_f_lisa.reshape(f_n_evol.shape)
    # characteristic noise squared: h_c,noise^2 = f^2 * S_n(f)
    h_c_lisa_2 = f_n_evol**2 * h_f_lisa

    snr_evol = h_c_n_2 / h_c_lisa_2

    # integrate, sum and square root to get SNR
    # (axis=1 is the time-step axis, so the integral runs over each harmonic's frequency sweep)
    snr_n_2 = np.trapz(y=snr_evol, x=f_n_evol, axis=1)

    if ret_snr2_by_harmonic:
        return snr_n_2.decompose().value

    snr_2 = snr_n_2.sum(axis=1)
    snr = np.sqrt(snr_2)

    if ret_max_snr_harmonic:
        # +1 converts the 0-based column index into the 1-based harmonic number
        max_snr_harmonic = np.argmax(snr_n_2, axis=1) + 1
        return snr.decompose().value, max_snr_harmonic
    else:
        return snr.decompose().value
Ejemplo n.º 7
0
def snr_circ_evolving(m_1,
                      m_2,
                      f_orb_i,
                      dist,
                      t_obs,
                      n_step,
                      position=None,
                      polarisation=None,
                      inclination=None,
                      t_merge=None,
                      interpolated_g=None,
                      interpolated_sc=None,
                      instrument="LISA",
                      custom_psd=None):
    """Computes SNR for circular and evolving sources

    The orbital frequency of each binary is evolved over the observation (or until one second
    before merger, whichever is first). The ratio of the n=2 characteristic strain squared to
    the characteristic noise squared is then integrated over the swept GW frequency (2 * f_orb).

    Parameters
    ----------
    m_1 : `float/array`
        Primary mass

    m_2 : `float/array`
        Secondary mass

    f_orb_i : `float/array`
        Initial orbital frequency

    dist : `float/array`
        Distance to the source

    t_obs : `float`
        Total duration of the observation

    n_step : `int`
        Number of time steps during observation duration

    position : `SkyCoord/array`, optional
        Sky position of source. Must be specified using Astropy's :class:`astropy.coordinates.SkyCoord` class.

    polarisation : `float/array`, optional
        GW polarisation angle of the source. Must have astropy angular units.

    inclination : `float/array`, optional
        Inclination of the source. Must have astropy angular units.

    t_merge : `float/array`
        Time until merger. Computed with :meth:`evol.get_t_merge_circ` when None.

    interpolated_g : `function`
        A function returned by :class:`scipy.interpolate.interp2d` that computes g(n,e) from Peters (1964).
        The code assumes that the function returns the output sorted as with the interp2d returned functions
        (and thus unsorts). Default is None and uses exact g(n,e) in this case.

    interpolated_sc : `function`
        A function returned by :class:`scipy.interpolate.interp1d` that computes the LISA sensitivity curve.
        Default is None and uses exact values. Note: take care to ensure that your interpolated function has
        the same LISA observation time as ``t_obs`` and uses the same instrument.

    instrument : `{{ 'LISA', 'TianQin', 'custom' }}`
        Instrument to observe with. If 'custom' then ``custom_psd`` must be supplied.

    custom_psd : `function`
        Custom function for computing the PSD. Must take the same arguments as :meth:`legwork.psd.lisa_psd`
        even if it ignores some.

    Returns
    -------
    snr : `float/array`
        SNR for each binary
    """
    m_c = utils.chirp_mass(m_1=m_1, m_2=m_2)

    # calculate minimum of observation time and merger time
    if t_merge is None:
        t_merge = evol.get_t_merge_circ(m_1=m_1, m_2=m_2, f_orb_i=f_orb_i)
    t_evol = np.minimum(t_merge - (1 * u.s), t_obs)

    # get f_orb evolution
    f_orb_evol = evol.evol_circ(t_evol=t_evol,
                                n_step=n_step,
                                m_1=m_1,
                                m_2=m_2,
                                f_orb_i=f_orb_i)

    # Frequencies equal to 1e2 Hz are treated as placeholders and replaced by the largest genuine
    # frequency of the same source (NOTE(review): presumably post-merger steps from evol_circ).
    maxes = np.where(f_orb_evol == 1e2 * u.Hz, -1 * u.Hz,
                     f_orb_evol).max(axis=1)
    for source in range(len(f_orb_evol)):
        f_orb_evol[source][f_orb_evol[source] == 1e2 * u.Hz] = maxes[source]

    # calculate the characteristic power of the n=2 harmonic. The sky position, polarisation and
    # inclination are forwarded just as in the other SNR functions; previously they were
    # accepted but silently ignored.
    h_c_n_2 = strain.h_c_n(m_c=m_c,
                           f_orb=f_orb_evol,
                           ecc=np.zeros_like(f_orb_evol).value,
                           n=2,
                           dist=dist,
                           position=position,
                           polarisation=polarisation,
                           inclination=inclination,
                           interpolated_g=interpolated_g)**2
    h_c_n_2 = h_c_n_2.reshape(len(m_c), n_step)

    # calculate the characteristic noise power at the GW frequency 2 * f_orb
    if interpolated_sc is not None:
        h_f_lisa_2 = interpolated_sc(2 * f_orb_evol.flatten())
        h_f_lisa_2 = h_f_lisa_2.reshape(f_orb_evol.shape)
    else:
        h_f_lisa_2 = psd.power_spectral_density(f=2 * f_orb_evol,
                                                t_obs=t_obs,
                                                instrument=instrument,
                                                custom_psd=custom_psd)
    # characteristic noise squared: h_c,noise^2 = f^2 * S_n(f)
    h_c_lisa_2 = (2 * f_orb_evol)**2 * h_f_lisa_2

    # integrate over the swept GW frequency (axis=1 is the time-step axis)
    snr = np.trapz(y=h_c_n_2 / h_c_lisa_2, x=2 * f_orb_evol, axis=1)**0.5

    return snr.decompose().value
Ejemplo n.º 8
0
def snr_circ_stationary(m_c,
                        f_orb,
                        dist,
                        t_obs,
                        position=None,
                        polarisation=None,
                        inclination=None,
                        interpolated_g=None,
                        interpolated_sc=None,
                        instrument="LISA",
                        custom_psd=None):
    """Computes SNR for circular and stationary sources

    A circular binary radiates only in the n=2 harmonic. The SNR is therefore
    sqrt(h_0_2^2 * t_obs / S_n(2 * f_orb)).

    Parameters
    ----------
    m_c : `float/array`
        Chirp mass

    f_orb : `float/array`
        Orbital frequency

    dist : `float/array`
        Distance to the source

    t_obs : `float`
        Total duration of the observation

    position : `SkyCoord/array`, optional
        Sky position of source. Must be specified using Astropy's :class:`astropy.coordinates.SkyCoord` class.

    polarisation : `float/array`, optional
        GW polarisation angle of the source. Must have astropy angular units.

    inclination : `float/array`, optional
        Inclination of the source. Must have astropy angular units.

    interpolated_g : `function`
        A function returned by :class:`scipy.interpolate.interp2d` that computes g(n,e) from Peters (1964).
        The code assumes that the function returns the output sorted as with the interp2d returned functions
        (and thus unsorts). Default is None and uses exact g(n,e) in this case.

    interpolated_sc : `function`
        A function returned by :class:`scipy.interpolate.interp1d` that computes the LISA sensitivity curve.
        Default is None and uses exact values. Note: take care to ensure that your interpolated function has
        the same LISA observation time as ``t_obs`` and uses the same instrument.

    instrument : `{{ 'LISA', 'TianQin', 'custom' }}`
        Instrument to observe with. If 'custom' then ``custom_psd`` must be supplied.

    custom_psd : `function`
        Custom function for computing the PSD. Must take the same arguments as :meth:`legwork.psd.lisa_psd`
        even if it ignores some.

    Returns
    -------
    snr : `float/array`
        SNR for each binary
    """
    # squared n=2 strain amplitude, flattened to one value per binary
    amplitude_sq = strain.h_0_n(m_c=m_c,
                                f_orb=f_orb,
                                ecc=np.zeros_like(f_orb).value,
                                n=2,
                                dist=dist,
                                position=position,
                                polarisation=polarisation,
                                inclination=inclination,
                                interpolated_g=interpolated_g).flatten()**2
    signal_power = amplitude_sq * t_obs

    gw_frequency = 2 * f_orb
    if interpolated_sc is None:
        noise_power = psd.power_spectral_density(f=gw_frequency,
                                                 t_obs=t_obs,
                                                 instrument=instrument,
                                                 custom_psd=custom_psd)
    else:
        noise_power = interpolated_sc(gw_frequency)

    return ((signal_power / noise_power)**0.5).decompose().value
Ejemplo n.º 9
0
def plot_sources_on_sc_ecc_stat(f_dom,
                                snr,
                                weights=None,
                                snr_cutoff=0,
                                t_obs="auto",
                                instrument="LISA",
                                custom_psd=None,
                                L="auto",
                                approximate_R=False,
                                confusion_noise="auto",
                                fig=None,
                                ax=None,
                                show=True,
                                **kwargs):
    """Overlay eccentric/stationary sources on the LISA sensitivity curve.

    Each source is plotted at its max snr harmonic frequency such that its height above the curve is
    equal to its signal-to-noise ratio. The heights are computed with the same instrument settings
    (``instrument``, ``t_obs``, ``L``, ``custom_psd``, ``approximate_R``, ``confusion_noise``) used to
    draw the curve.

    Parameters
    ----------
    f_dom : `float/array`
        Dominant harmonic frequency (f_orb * n_dom where n_dom is the harmonic with the maximum snr)

    snr : `float/array`
        Signal-to-noise ratio

    weights : `float/array`, optional, default=None
        Statistical weights for each source, default is equal weights

    snr_cutoff : `float`
        SNR above which to plot binaries (default is 0 such that all sources are plotted)

    instrument: {{ `LISA`, `TianQin`, `custom` }}
        Instrument to use. LISA is used by default. Choosing `custom` uses ``custom_psd`` to compute PSD.

    custom_psd : `function`
        Custom function for computing the PSD. Must take the same arguments as :meth:`legwork.psd.lisa_psd`
        even if it ignores some.

    t_obs : `float`
        Observation time (default auto)

    L : `float`
        Arm length in metres

    approximate_R : `boolean`
        Whether to approximate the response function (default: no)

    confusion_noise : `various`
        Galactic confusion noise. Acceptable inputs are either one of the values listed in
        :meth:`legwork.psd.get_confusion_noise`, "auto" (automatically selects confusion noise based on
        `instrument` - 'robson19' if LISA and 'huang20' if TianQin), or a custom function that gives the
        confusion noise at each frequency for a given mission length where it would be called by running
        `noise(f, t_obs)` and return a value with units of inverse Hertz

    fig: `matplotlib Figure`
        A figure on which to plot the distribution. Both `ax` and `fig` must be supplied for either to be used.
        When supplied, no sensitivity curve is drawn, but the source heights still use the noise settings above.

    ax: `matplotlib Axis`
        An axis on which to plot the distribution. Both `ax` and `fig` must be supplied for either to be used

    show : `boolean`
        Whether to immediately show the plot or only return the Figure and Axis

    **kwargs : `various`
        This function is a wrapper on :func:`legwork.visualisation.plot_2D_dist` and each kwarg is passed
        directly to this function. For example, you can write `disttype="kde"` for a kde density plot
        instead of a scatter plot.

    Returns
    -------
    fig : `matplotlib Figure`
        The figure on which the distribution is plotted

    ax : `matplotlib Axis`
        The axis on which the distribution is plotted
    """
    # Noise settings shared by the drawn curve and the source heights. Previously the heights
    # ignored them and always used the default PSD, which misplaced sources for non-default
    # instruments.
    noise_settings = dict(t_obs=t_obs,
                          instrument=instrument,
                          L=L,
                          custom_psd=custom_psd,
                          approximate_R=approximate_R,
                          confusion_noise=confusion_noise)

    # create figure if it wasn't provided
    if fig is None or ax is None:
        fig, ax = plot_sensitivity_curve(show=False, **noise_settings)

    # work out which binaries are above the cutoff
    detectable = snr > snr_cutoff
    if not detectable.any():
        print("ERROR: There are no binaries above provided `snr_cutoff`")
        return fig, ax

    # calculate asd that makes it so height above curve is snr
    asd = snr[detectable] * np.sqrt(
        psd.power_spectral_density(f_dom[detectable], **noise_settings))

    # plot either a scatter or density plot of the detectable binaries, keeping the curve's y-limits
    ylims = ax.get_ylim()
    weights = weights[detectable] if weights is not None else None
    fig, ax = plot_2D_dist(x=f_dom[detectable],
                           y=asd.to(u.Hz**(-1 / 2)),
                           weights=weights,
                           fig=fig,
                           ax=ax,
                           show=False,
                           **kwargs)
    ax.set_ylim(ylims)

    if show:
        plt.show()

    return fig, ax
Ejemplo n.º 10
0
def plot_sensitivity_curve(frequency_range=None,
                           y_quantity="ASD",
                           fig=None,
                           ax=None,
                           show=True,
                           color="#18068b",
                           fill=True,
                           alpha=0.2,
                           linewidth=1,
                           label=None,
                           **kwargs):
    """Plot a detector sensitivity curve (LISA unless other noise settings are given in ``kwargs``)

    Parameters
    ----------
    frequency_range : `float array`
        Frequency values at which to plot the sensitivity curve. Defaults to 1000 log-spaced
        values between 1e-5 and 1 Hz.

    y_quantity : `{{ "ASD", "h_c" }}`
        Which quantity to plot on the y axis: the amplitude spectral density sqrt(S_n), or the
        characteristic strain sqrt(f * S_n)

    fig: `matplotlib Figure`
        A figure on which to plot the distribution. Both `ax` and `fig` must be supplied for either to be used

    ax: `matplotlib Axis`
        An axis on which to plot the distribution. Both `ax` and `fig` must be supplied for either to be used

    show : `boolean`
        Whether to immediately show the plot or only return the Figure and Axis

    color : `string or tuple`
        Colour to use for the curve, see https://matplotlib.org/tutorials/colors/colors.html for details on
        how to specify a colour

    fill : `boolean`
        Whether to fill the area below the sensitivity curve

    alpha : `float`
        Opacity of the filled area below the sensitivity curve (ignored if fill is `False`)

    linewidth : `float`
        Width of the sensitivity curve

    label : `string`
        Label for the sensitivity curve in legends

    **kwargs : `various`
        Keyword args are passed to :meth:`legwork.psd.power_spectral_density`, see those docs for details on
        possible arguments.

    Returns
    -------
    fig : `matplotlib Figure`
        The figure on which the distribution is plotted

    ax : `matplotlib Axis`
        The axis on which the distribution is plotted

    Raises
    ------
    ValueError
        If ``y_quantity`` is neither "ASD" nor "h_c"
    """
    freqs = np.logspace(-5, 0, 1000) * u.Hz if frequency_range is None else frequency_range

    if fig is None or ax is None:
        fig, ax = plt.subplots()

    # noise amplitude and matching axis label for the requested quantity
    spectral_density = psd.power_spectral_density(f=freqs, **kwargs)
    if y_quantity == "ASD":
        amplitude, y_label = np.sqrt(spectral_density), r'ASD $[\rm Hz^{-1/2}]$'
    elif y_quantity == "h_c":
        amplitude, y_label = np.sqrt(freqs * spectral_density), r'Characteristic Strain'
    else:
        raise ValueError("y_quantity must be one of 'ASD' or 'h_c'")

    # draw the curve (and optionally shade underneath it) with astropy units handled by matplotlib
    with quantity_support():
        ax.loglog(freqs, amplitude, color=color, label=label, linewidth=linewidth)
        if fill:
            ax.fill_between(freqs, np.zeros_like(amplitude), amplitude, alpha=alpha, color=color)

    ax.set_xlabel(r'Frequency [$\rm Hz$]')
    ax.set_ylabel(y_label)

    # make the curve flush with the left/right edges of the axis
    ax.tick_params(axis='both', which='major')
    ax.set_xlim(np.min(freqs).value, np.max(freqs).value)

    if show:
        plt.show()

    return fig, ax