Esempio n. 1
0
    def plot_samples(self,
                     samples,
                     plot_fields=('y',),
                     start='2020-03-04',
                     T=None,
                     ax=None,
                     legend=True,
                     forecast=False,
                     n_samples=0,
                     intervals=(50, 80, 95)):
        '''
        Plotting method for SIR-type models.

        For each field in ``plot_fields`` this draws the median over samples
        as a line, optionally ``n_samples`` individual sample paths, and one
        shaded band per central prediction interval in ``intervals``. Bands
        and sample paths are drawn in the color assigned to their field.

        Args:
            samples: posterior samples; read through ``self.horizon`` and
                ``self.get``.
            plot_fields: fields to plot. Each must be a key of ``self.names``,
                which supplies the legend label.
            start: first date of the daily time axis.
            T: optional cap on the number of days plotted; the value actually
                used is ``min(T, self.horizon(samples, forecast=forecast))``.
            ax: Axes to draw on. If None, a new Axes is created with
                ``plt.axes()``; otherwise ``ax`` is made the current Axes.
            legend: whether pandas adds a legend for the median lines.
            forecast: forwarded to ``self.horizon`` and ``self.get``.
            n_samples: number of individual sample paths to draw per field;
                0 draws none.
            intervals: central interval widths in percent; e.g. 95 shades
                the band between the 2.5th and 97.5th percentiles.

        Returns:
            ``(median_max, pi_max)``: the per-field maxima of the median
            curves, and the largest upper interval bound across all fields
            and intervals, floored at 10.
        '''
        # In recent Matplotlib, plt.axes(ax) raises when ax lives in a figure
        # other than the current pyplot figure; plt.sca(ax) handles both cases.
        if ax is None:
            ax = plt.axes()
        else:
            plt.sca(ax)

        T_data = self.horizon(samples, forecast=forecast)
        T = T_data if T is None else min(T, T_data)

        # Adding 0.0 yields a float array (and a fresh copy) for each field.
        fields = {f: 0.0 + self.get(samples, f, forecast=forecast)[:, :T]
                  for f in plot_fields}
        names = {f: self.names[f] for f in plot_fields}
        medians = {names[f]: np.median(v, axis=0) for f, v in fields.items()}

        t = pd.date_range(start=start, periods=T, freq='D')

        ax.set_prop_cycle(None)
        colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

        def field_color(i):
            # Wrap around instead of raising IndexError when there are more
            # fields than colors in the property cycle.
            return colors[i % len(colors)]

        # Plot medians
        median_df = pd.DataFrame(index=t, data=medians)
        median_df.plot(ax=ax, legend=legend)
        median_max = median_df.max().values

        # Plot individual sample paths if requested, in their field's color.
        if n_samples > 0:
            for i, v in enumerate(fields.values()):
                sample_df = pd.DataFrame(index=t, data=v[:n_samples, :].T)
                sample_df.plot(ax=ax, legend=False, alpha=0.1,
                               color=field_color(i))

        # Plot prediction intervals; pi_max starts at a floor of 10.
        pi_max = 10
        for interval in intervals:
            low = (100. - interval) / 2
            high = 100. - low
            for i, v in enumerate(fields.values()):
                pi = np.percentile(v, (low, high), axis=0)
                ax.fill_between(t, pi[0, :], pi[1, :], alpha=0.1,
                                color=field_color(i), label=interval)
                pi_max = np.maximum(pi_max, np.nanmax(pi[1, :]))

        return median_max, pi_max
Esempio n. 2
0
File: base.py Progetto: elray1/covid
    def plot_samples(self,
                     samples,
                     plot_fields=('y',),
                     start='2020-03-04',
                     T=None,
                     ax=None,
                     legend=True,
                     forecast=False,
                     interval=80):
        '''
        Plotting method for SIR-type models.

        Draws, for each field in ``plot_fields``, the median over samples as
        a line and a shaded central prediction interval labelled 'CI'.

        Args:
            samples: posterior samples; read through ``self.horizon`` and
                ``self.get``.
            plot_fields: fields to plot. Each must be a key of ``self.names``,
                which supplies the legend label.
            start: first date of the daily time axis.
            T: optional cap on the number of days plotted; the value actually
                used is ``min(T, self.horizon(samples, forecast=forecast))``.
            ax: Axes to draw on. If None, a new Axes is created with
                ``plt.axes()``; otherwise ``ax`` is made the current Axes.
            legend: whether pandas adds a legend for the median lines.
            forecast: forwarded to ``self.horizon`` and ``self.get``.
            interval: width in percent of the shaded central interval; the
                default 80 shades the 10th-90th percentile band.

        Returns:
            ``(median_max, pi_max)``: the per-field maxima of the median
            curves, and the largest upper interval bound, floored at 10.
        '''
        # In recent Matplotlib, plt.axes(ax) raises when ax lives in a figure
        # other than the current pyplot figure; plt.sca(ax) handles both cases.
        if ax is None:
            ax = plt.axes()
        else:
            plt.sca(ax)

        T_data = self.horizon(samples, forecast=forecast)
        T = T_data if T is None else min(T, T_data)

        fields = {
            f: self.get(samples, f, forecast=forecast)[:, :T]
            for f in plot_fields
        }
        names = {f: self.names[f] for f in plot_fields}

        medians = {names[f]: np.median(v, axis=0) for f, v in fields.items()}
        low = (100. - interval) / 2
        high = 100. - low
        pred_intervals = {
            names[f]: np.percentile(v, (low, high), axis=0)
            for f, v in fields.items()
        }

        t = pd.date_range(start=start, periods=T, freq='D')

        ax.set_prop_cycle(None)

        # Plot medians
        df = pd.DataFrame(index=t, data=medians)
        df.plot(ax=ax, legend=legend)
        median_max = df.max().values

        # Plot prediction intervals; pi_max starts at a floor of 10.
        pi_max = 10
        for pi in pred_intervals.values():
            ax.fill_between(t, pi[0, :], pi[1, :], alpha=0.1, label='CI')
            pi_max = np.maximum(pi_max, np.nanmax(pi[1, :]))

        return median_max, pi_max
Esempio n. 3
0
def test_VALD_MODIT():
    """Smoke-test a MODIT forward model mixing VALD atoms and ExoMol molecules.

    Builds line opacity for VALD atoms/ions and for FeH, H2O, OH and TiO
    (ExoMol), adds H- and H2-H2 CIA continuum, runs radiative transfer,
    rigid rotation and instrumental-profile sampling, and checks that the
    normalized spectrum ``mu`` is finite and nonzero everywhere.
    """
    # Wavelength range (unit="AA" in the nugrid call below).
    wls, wll = 10395, 10405

    # Model atmospheric layers, wavenumber grid for the model, and instrument.
    NP = 100
    Parr, dParr, _ = pressure_layer(NP=NP)
    Pref = 1.0  # bar
    ONEARR = np.ones_like(Parr)

    # The model grid is padded by 5 on each side of the wavelength range.
    Nx = 2000
    nus, _, _ = nugrid(wls - 5.0, wll + 5.0, Nx, unit="AA", xsmode="modit")

    Rinst = 100000.  # instrumental spectral resolution
    # Equivalent to beta = c / (2.0 * np.sqrt(2.0 * np.log(2.0)) * R).
    beta_inst = R2STD(Rinst)

    # Atoms and ions from VALD. crit is set just in case some weak lines
    # cause an error that results in a gamma of 0 (note dated 220219).
    adbV = moldb.AdbVald(path_ValdLineList, nus, crit=1e-100)
    asdb = moldb.AdbSepVald(adbV)

    # Molecules from ExoMol.
    mdbH2O = moldb.MdbExomol('.database/H2O/1H2-16O/POKAZATEL', nus,
                             crit=1e-50)
    mdbTiO = moldb.MdbExomol('.database/TiO/48Ti-16O/Toto', nus, crit=1e-50)
    mdbOH = moldb.MdbExomol('.database/OH/16O-1H/MoLLIST', nus)
    mdbFeH = moldb.MdbExomol('.database/FeH/56Fe-1H/MoLLIST', nus)

    # Collision-induced absorption.
    cdbH2H2 = contdb.CdbCIA('.database/H2-H2_2011.cia', nus)

    # Molecular masses.
    molmassH2O = molinfo.molmass("H2O")
    molmassTiO = molinfo.molmass("TiO")
    molmassOH = molinfo.molmass("OH")
    molmassFeH = molinfo.molmass("FeH")

    # MODIT initialization (separate VALD species, and ExoMol molecules).
    # R and pmarray are reassigned by every call; the last values are used
    # below (presumably identical, since each depends only on nus — verify).
    cnuS, indexnuS, R, pmarray = initspec.init_modit_vald(
        asdb.nu_lines, nus, asdb.N_usp)
    cnu_FeH, indexnu_FeH, R, pmarray = initspec.init_modit(
        mdbFeH.nu_lines, nus)
    cnu_H2O, indexnu_H2O, R, pmarray = initspec.init_modit(
        mdbH2O.nu_lines, nus)
    cnu_OH, indexnu_OH, R, pmarray = initspec.init_modit(mdbOH.nu_lines, nus)
    cnu_TiO, indexnu_TiO, R, pmarray = initspec.init_modit(
        mdbTiO.nu_lines, nus)

    # Power-law T-P profiles T0 * (P / Pref)**alpha, one row per
    # (T0, alpha) pair, sampling the max/min of the temperature profiles.
    def fT(T0, alpha):
        return T0[:, None] * (Parr[None, :] / Pref)**alpha[:, None]

    T0_test = np.array([1500.0, 4000.0, 1500.0, 4000.0])
    alpha_test = np.array([0.2, 0.2, 0.05, 0.05])
    # Grid resolution passed to the setdgm_* precomputations.
    dgm_res = 0.2

    # Assume a typical atmosphere: VMRs of H, He and H2.
    H_He_HH_VMR_ref = [0.1, 0.15, 0.75]
    PH_ref = Parr * H_He_HH_VMR_ref[0]
    PHe_ref = Parr * H_He_HH_VMR_ref[1]
    PHH_ref = Parr * H_He_HH_VMR_ref[2]

    # Precompute dgm_ngammaL for every species.
    dgm_ngammaL_VALD = setdgm_vald_all(asdb, PH_ref, PHe_ref, PHH_ref, R, fT,
                                       dgm_res, T0_test, alpha_test)
    dgm_ngammaL_FeH = setdgm_exomol(mdbFeH, fT, Parr, R, molmassFeH, dgm_res,
                                    T0_test, alpha_test)
    dgm_ngammaL_H2O = setdgm_exomol(mdbH2O, fT, Parr, R, molmassH2O, dgm_res,
                                    T0_test, alpha_test)
    dgm_ngammaL_OH = setdgm_exomol(mdbOH, fT, Parr, R, molmassOH, dgm_res,
                                   T0_test, alpha_test)
    dgm_ngammaL_TiO = setdgm_exomol(mdbTiO, fT, Parr, R, molmassTiO, dgm_res,
                                    T0_test, alpha_test)

    T0 = 3000.
    alpha = 0.07
    # These ratios look like solar-to-Jupiter mass and radius conversions
    # (g and cm) — NOTE(review): confirm the intended units.
    Mp = 0.155 * 1.99e33 / 1.90e30
    Rp = 0.186 * 6.96e10 / 6.99e9
    u1 = 0.0
    u2 = 0.0
    RV = 0.00
    vsini = 2.0

    mmw = 2.33 * ONEARR  # mean molecular weight
    log_e_H = -4.2
    VMR_H = 0.09
    VMR_H2 = 0.77
    VMR_FeH = 10**-8
    VMR_H2O = 10**-4
    VMR_OH = 10**-4
    VMR_TiO = 10**-8
    A_Fe = 1.5
    A_Ti = 1.2

    adjust_continuum = 0.99

    # Presumably surface gravity from Mp, Rp in Jupiter units — verify.
    ga = 2478.57730044555 * Mp / Rp**2
    Tarr = T0 * (Parr / Pref)**alpha
    PH = Parr * VMR_H
    PHe = Parr * (1 - VMR_H - VMR_H2)
    PHH = Parr * VMR_H2
    VMR_e = VMR_H * 10**log_e_H

    # VMR of atoms and ions, with abundance modifications for Fe and Ti.
    mods_ID = jnp.array([[26, 1], [22, 1]])
    mods = jnp.array([A_Fe, A_Ti])
    VMR_uspecies = atomll.get_VMR_uspecies(asdb.uspecies, mods_ID, mods)
    VMR_uspecies = VMR_uspecies[:, None] * ONEARR

    # Compute delta tau.

    # Atoms & ions (VALD)
    SijMS, ngammaLMS, nsigmaDlS = vald_all(asdb, Tarr, PH, PHe, PHH, R)
    xsmS = xsmatrix_vald(cnuS, indexnuS, R, pmarray, nsigmaDlS, ngammaLMS,
                         SijMS, nus, dgm_ngammaL_VALD)
    dtauatom = dtauVALD(dParr, xsmS, VMR_uspecies, mmw, ga)

    # FeH
    SijM_FeH, ngammaLM_FeH, nsigmaDl_FeH = exomol(mdbFeH, Tarr, Parr, R,
                                                  molmassFeH)
    xsm_FeH = xsmatrix(cnu_FeH, indexnu_FeH, R, pmarray, nsigmaDl_FeH,
                       ngammaLM_FeH, SijM_FeH, nus, dgm_ngammaL_FeH)
    dtaum_FeH = dtauM_mmwl(dParr, jnp.abs(xsm_FeH), VMR_FeH * ONEARR, mmw, ga)

    # H2O
    SijM_H2O, ngammaLM_H2O, nsigmaDl_H2O = exomol(mdbH2O, Tarr, Parr, R,
                                                  molmassH2O)
    xsm_H2O = xsmatrix(cnu_H2O, indexnu_H2O, R, pmarray, nsigmaDl_H2O,
                       ngammaLM_H2O, SijM_H2O, nus, dgm_ngammaL_H2O)
    dtaum_H2O = dtauM_mmwl(dParr, jnp.abs(xsm_H2O), VMR_H2O * ONEARR, mmw, ga)

    # OH
    SijM_OH, ngammaLM_OH, nsigmaDl_OH = exomol(mdbOH, Tarr, Parr, R, molmassOH)
    xsm_OH = xsmatrix(cnu_OH, indexnu_OH, R, pmarray, nsigmaDl_OH, ngammaLM_OH,
                      SijM_OH, nus, dgm_ngammaL_OH)
    dtaum_OH = dtauM_mmwl(dParr, jnp.abs(xsm_OH), VMR_OH * ONEARR, mmw, ga)

    # TiO
    SijM_TiO, ngammaLM_TiO, nsigmaDl_TiO = exomol(mdbTiO, Tarr, Parr, R,
                                                  molmassTiO)
    xsm_TiO = xsmatrix(cnu_TiO, indexnu_TiO, R, pmarray, nsigmaDl_TiO,
                       ngammaLM_TiO, SijM_TiO, nus, dgm_ngammaL_TiO)
    dtaum_TiO = dtauM_mmwl(dParr, jnp.abs(xsm_TiO), VMR_TiO * ONEARR, mmw, ga)

    # H-
    dtau_Hm = dtauHminus_mmwl(nus, Tarr, Parr, dParr, VMR_e * ONEARR,
                              VMR_H * ONEARR, mmw, ga)

    # CIA
    dtauc_H2H2 = dtauCIA_mmwl(nus, Tarr, Parr, dParr, VMR_H2 * ONEARR,
                              VMR_H2 * ONEARR, mmw, ga, cdbH2H2.nucia,
                              cdbH2H2.tcia, cdbH2H2.logac)

    # Summation of all opacity sources.
    dtau = (dtauatom + dtaum_FeH + dtaum_H2O + dtaum_OH + dtaum_TiO
            + dtau_Hm + dtauc_H2H2)

    sourcef = planck.piBarr(Tarr, nus)
    F0 = rtrun(dtau, sourcef)
    Frot = response.rigidrot(nus, F0, vsini, u1, u2)
    wavd = jnp.linspace(wls, wll, 500)
    nusd = jnp.array(1.e8 / wavd[::-1])
    mu = response.ipgauss_sampling(nusd, nus, Frot, beta_inst, RV)
    mu = mu / jnp.nanmax(mu) * adjust_continuum

    # isfinite rejects both NaN and +/-inf; separate asserts say which failed.
    assert np.all(np.isfinite(mu)), "spectrum contains NaN or inf"
    assert np.all(mu != 0), "spectrum contains zeros"
Esempio n. 4
0
def nanmax(x, axis=None, keepdims=None):
  """Return the NaN-ignoring maximum of ``x`` via ``jnp.nanmax``.

  A ``JaxArray`` input is unwrapped to its ``.value`` first. With
  ``axis=None`` the raw ``jnp.nanmax`` result is returned; with an explicit
  axis the result is wrapped in a ``JaxArray``.
  """
  data = x.value if isinstance(x, JaxArray) else x
  result = jnp.nanmax(data, axis=axis, keepdims=keepdims)
  if axis is not None:
    result = JaxArray(result)
  return result
Esempio n. 5
0
def per_layer_figure(*,
                     state,
                     key_format,
                     items,
                     title,
                     xlabel,
                     ylabel,
                     show_values=False):
    """Generates a figure with a subplot per layer with consistent scales.

    Each entry of ``items`` gets one ``imshow`` subplot titled
    ``'Time = <index>'``. All subplots share one color scale, spanning the
    finite values of every entry, and a single colorbar on the right.

    Args:
        state: source of the plotted values. If ``key_format`` is truthy the
            value for ``item`` is ``state[key_format.format(item)][0]``;
            otherwise the value at position ``index`` is ``state[index]``.
        key_format: format string turning an item into a key of ``state``,
            or a falsy value to index ``state`` by position.
        items: the entries to plot, one subplot each; must be non-empty.
        title: figure title.
        xlabel: x-axis label, set on every subplot.
        ylabel: y-axis label, set on the first subplot only.
        show_values: if true and a value has fewer than 25 rows, write each
            element's numeric value on top of its cell.

    Returns:
        The created matplotlib Figure.
    """
    num_items = len(items)
    # squeeze=False always yields a 2-D array of Axes, so the single-subplot
    # case needs no special handling.
    fig, axes = plt.subplots(nrows=1,
                             ncols=num_items,
                             figsize=(num_items * 3, 3),
                             squeeze=False)
    fig.suptitle(title)

    def get_value(index, item):
        if key_format:
            # Only the first entry stored under the key is plotted.
            return state[key_format.format(item)][0]
        return state[index]

    # Look each value up once: it is needed for the shared scale and the plot.
    values = [get_value(index, item) for index, item in enumerate(items)]

    # Shared color limits. Non-finite entries are masked to NaN so that
    # nanmin/nanmax skip them.
    vmin = jnp.inf
    vmax = -jnp.inf
    for value in values:
        finite = jnp.where(jnp.isfinite(value), value, jnp.nan)
        vmin = jnp.minimum(vmin, jnp.nanmin(finite))
        vmax = jnp.maximum(vmax, jnp.nanmax(finite))

    for index, (ax, value) in enumerate(zip(axes[0], values)):
        ax.set_title(f'Time = {index}')
        ax.set_xlabel(xlabel)
        if index == 0:
            ax.set_ylabel(ylabel)
        im = ax.imshow(value, vmin=vmin, vmax=vmax)

        if show_values and len(value) < 25:
            # Add text overlays indicating the numerical value.
            for node_index, row in enumerate(value):
                for timestep, v in enumerate(row):
                    ax.text(timestep,
                            node_index,
                            str(v),
                            horizontalalignment='center',
                            verticalalignment='center',
                            color='black')

        ax.set_aspect('equal')

    # Colorbar layout in figure-relative units: shrink the subplot area from
    # the right, then place a thin colorbar Axes in the freed space.
    cbar_width = 0.05  # Fraction of a plot
    cbar_padding = 0.05
    half_padded_cbar_width = cbar_width + cbar_padding
    padded_cbar_width = cbar_width + 2 * cbar_padding
    fig.subplots_adjust(right=1 - padded_cbar_width /
                        (num_items + padded_cbar_width))
    # add_axes takes [left, bottom, width, height].
    cbar_ax = fig.add_axes([
        1 - half_padded_cbar_width / (num_items + padded_cbar_width),  # left
        0.15,  # bottom
        cbar_width / (num_items + padded_cbar_width),  # width
        0.7,  # height
    ])
    # Every image shares vmin/vmax, so the last one represents the scale.
    fig.colorbar(im, cax=cbar_ax)
    return fig
Esempio n. 6
0
if __name__ == '__main__':
    # Demo: run RMHMC on a 2-D funnel target with the SoftAbs metric, then
    # plot exp(-neglog) on a grid covering the sampler's path, with the path
    # overlaid.
    # NOTE(review): both `np` and `onp` are used; presumably np is jax.numpy
    # and onp is plain NumPy — confirm against the file's imports.

    target = Target(funnel_neglog, 2, metric_fun=softabs)
    x_init = np.array([0., 0.])
    p_init = np.array([0., 0.])
    # Seed is drawn from host NumPy, so each run differs.
    # TODO confirm: RMHMC's positional args look like (n_samples, target, x0, p0).
    hmc = RMHMC(1000, target, x_init, p_init, seed=onp.random.randint(1, 1000))
    # Presumably enables recording of hmc.path, which the plot below reads.
    hmc.track = True
    # NOTE(review): metric_fun was already passed to Target above; this
    # reassignment looks redundant unless the constructor does not store it.
    target.metric_fun = softabs
    target.softabs_const = 1e0
    # Scale the sampler's tuning attributes (presumably step size and number
    # of integration steps — verify against RMHMC).
    hmc.epsilon *= 0.05
    hmc.l *= 20
    hmc.run()
    print('is there any nan here? {}'.format(onp.any(onp.isnan(hmc.samples))))
    plt.figure(figsize=(10, 10))

    # 501x501 grid spanning the range of the recorded path (NaNs ignored).
    x = np.linspace(np.nanmin(hmc.path[:, 0]), np.nanmax(hmc.path[:, 0]), 501)
    y = np.linspace(np.nanmin(hmc.path[:, 1]), np.nanmax(hmc.path[:, 1]), 501)
    X, Y = np.meshgrid(x, y)
    # Unnormalized target density on the grid.
    Z = np.exp(-target.neglog((X, Y)))
    plt.imshow(Z,
               vmin=Z.min(),
               vmax=Z.max(),
               origin='lower',
               extent=[x.min(), x.max(), y.min(),
                       y.max()],
               cmap=plt.cm.gist_earth_r)
    # Overlay the sampler's trajectory as a faint black line.
    plt.plot(hmc.path[:, 0],
             hmc.path[:, 1],
             alpha=0.2,
             linewidth=1,
             color='black')