Example #1
def make_simultaneous_hierarchical_kde_plot(n_mean_plots=100,
                                            infolder='hierarchical-model'):
    with pm.Model():
        trace = pm.load_trace(os.path.join('uniform-traces', infolder))
    psi = trace.get_values('psi')
    mu_psi = trace.get_values('mu_psi')
    df = pd.DataFrame(
        psi,
        columns=['psi-{}'.format(i) for i in range(psi.shape[1])])
    df['mu_psi'] = mu_psi
    sigma_psi = pd.Series(trace.get_values('sigma_psi'))
    # plotting
    f, (ax0, ax1) = plt.subplots(ncols=1,
                                 nrows=2,
                                 figsize=(10, 6),
                                 dpi=200,
                                 sharex=True,
                                 gridspec_kw={'height_ratios': [3, 1]})
    plt.sca(ax0)
    df.drop('mu_psi', axis=1).plot.kde(ax=plt.gca(),
                                       linestyle='dashed',
                                       legend=None)
    plt.sca(ax1)
    mean_dist_params = pd.concat((df['mu_psi'], sigma_psi), axis=1)
    for _ in trange(n_mean_plots):
        mu, sigma = mean_dist_params.sample().values.T
        x = np.linspace(0, 50, 500)
        y = np.exp(-(x - mu)**2 /
                   (2 * sigma**2)) / np.sqrt(2 * np.pi * sigma**2)
        plt.fill_between(x, np.zeros_like(y), y, alpha=0.008, color='k')
    plt.xlim(0, 50)
    ax1.set_ylim(0, ax1.get_ylim()[1])
    plt.suptitle(r"KDE for individual galaxies' $\psi$ and $\mu_\psi$")
    plt.xlabel('Pitch angle (degrees)')
    plt.savefig('hierarchical_gal_result.png', bbox_inches='tight')
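This snippet assumes the usual PyMC3-era imports; a plausible header (an assumption, not shown in the source) would be:

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
from tqdm import trange  # provides the trange progress bar used above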
Example #2
def load_trace(start, n_weeks):
    filename = "../data/mcmc_samples_backup/parameters_covid19_{}".format(start)
    model = load_model(start, n_weeks)
    with model:
        trace = pm.load_trace(filename)
    del model
    return trace
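Most of the loader helpers on this page follow the same pattern: rebuild or reload the model, enter its context, and call pm.load_trace on the backup directory. A minimal sketch of the full save/load round trip in the PyMC3 3.x API, with illustrative paths:

import pymc3 as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    trace = pm.sample(500, tune=500, return_inferencedata=False)
    pm.save_trace(trace, "saved-trace", overwrite=True)  # writes a directory

# load_trace needs the defining model, either on the context stack
# or passed explicitly via model=
with model:
    restored = pm.load_trace("saved-trace")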
Example #3
    def estimate_parameters(self, df, lat, lon, map_estimate):
        x_fourier = fourier.get_fourier_valid(df, self.modes)
        x_fourier_01 = (x_fourier + 1) / 2
        x_fourier_01.columns = ["pos" + col for col in x_fourier_01.columns]

        dff = pd.concat([df, x_fourier, x_fourier_01], axis=1)
        df_subset = dh.get_subset(dff, self.subset, self.seed, self.startdate)

        self.model = self.statmodel.setup(df_subset)

        outdir_for_cell = dh.make_cell_output_dir(self.output_dir, "traces",
                                                  lat, lon, self.variable)
        if map_estimate:
            try:
                with open(outdir_for_cell, 'rb') as handle:
                    trace = pickle.load(handle)
            except Exception as e:
                print("Problem with saved trace:", e,
                      ". Redo parameter estimation.")
                trace = pm.find_MAP(model=self.model)
                if self.save_trace:
                    with open(outdir_for_cell, 'wb') as handle:
                        free_params = {
                            key: value
                            for key, value in trace.items()
                            if key.startswith('weights') or key == 'logp'
                        }
                        pickle.dump(free_params,
                                    handle,
                                    protocol=pickle.HIGHEST_PROTOCOL)
        else:
            # FIXME: Rework loading old traces
            # print("Search for trace in\n", outdir_for_cell)
            # As load_trace does not throw an error when no saved data exists,
            # we test for that manually here. FIXME: could be improved; we only
            # check for the existence of names and the number of chains, not
            # that the data is uncorrupted.
            try:
                trace = pm.load_trace(outdir_for_cell, model=self.model)
                print(trace.varnames)
                #     for var in self.statmodel.vars_to_estimate:
                #         if var not in trace.varnames:
                #             print(var, "is not in trace, rerun sampling.")
                #             raise IndexError
                #     if trace.nchains != self.chains:
                #         raise IndexError("Sample data not completely saved. Rerun.")
                print("Successfully loaded sampled data from")
                print(outdir_for_cell)
                print("Skip this for sampling.")
            except Exception as e:
                print("Problem with saved trace:", e,
                      ". Redo parameter estimation.")
                trace = self.sample()
                # print(pm.summary(trace))  # takes too much memory
                if self.save_trace:
                    pm.backends.save_trace(trace,
                                           outdir_for_cell,
                                           overwrite=True)

        return trace, dff
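Note the asymmetry this function works around: pm.find_MAP returns a plain dict of point estimates, which pickles directly, while full sampling traces go through pm.backends.save_trace. A hypothetical reader for the pickled MAP subset:

import pickle

# outdir_for_cell is whatever dh.make_cell_output_dir returned above
with open(outdir_for_cell, "rb") as handle:
    free_params = pickle.load(handle)  # dict of 'weights*' entries plus 'logp'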
Example #4
def load_trace_window(disease, model_i, start, n_weeks):
    filename_params = "../data/mcmc_samples_backup/parameters_{}_{}".format(
        disease, start)
    model = load_model_window(disease, model_i, start, n_weeks)
    with model:
        trace = pm.load_trace(filename_params)
    del model
    return trace
Example #5
def load_trace_by_i(disease, i):
    filename_params = "../data/mcmc_samples_backup/parameters_{}_{}".format(
        disease, i)

    model = load_model_by_i(disease, i)
    with model:
        trace = pm.load_trace(filename_params)
    del model
    return trace
Example #6
def load_trace(disease, use_age, use_eastwest):
    filename_params = "../data/mcmc_samples/parameters_{}_{}_{}".format(disease, use_age, use_eastwest)

    model = load_model(disease, use_age, use_eastwest)
    with model:
        trace = pm.load_trace(filename_params)
    
    del model
    return trace
Example #7
def fig4(name, func, eps):
    """Makes figure 4.

    Args:
        name (str): Descriptive name of the model. Posterior samples, statistics, and
            figures are generated and saved in a subdirectory with this name.
        func (callable): Function for model construction. Should
            return a formatted copy of the data.
        eps (bool): If True, saves the figures to the manuscript subdirectory in .eps
            format.

    """

    with pm.Model() as m:

        fit_model(name, func)
        trace = pm.load_trace(name)
        params = sorted(
            [p.name for p in m.deterministics if "Lambda" in p.name])

    set_fig_defaults()
    rcParams["figure.figsize"] = (3, 3 * 2)
    fig, axes = plt.subplots(5, 1, constrained_layout=True)

    for p, ax in zip(params, axes):

        vals, bins, _ = ax.hist(trace[p],
                                bins=50,
                                density=True,
                                histtype="step",
                                color="lightgray")
        ax.set_xlabel(p)
        if ax == axes[0]:
            ax.set_ylabel("Posterior density")

        start, stop = pm.stats.hpd(trace[p])
        for n, l, r in zip(vals, bins, bins[1:]):

            if l > start:
                if r < stop:
                    ax.fill_between([l, r], 0, [n, n], color="lightgray")
                elif l < stop < r:
                    ax.fill_between([l, stop], 0, [n, n], color="lightgray")
            elif l < start < r:
                ax.fill_between([start, r], 0, [n, n], color="lightgray")

        x = np.linspace(min([bins[0], 0]), max([0, bins[-1]]))
        theta = skewnorm.fit(trace[p])
        ax.plot(x, skewnorm.pdf(x, *theta), "k", label="Normal approx.")
        ax.plot(x, norm.pdf(x), "k--", label="Prior")
        ax.plot([0, 0], [skewnorm.pdf(0, *theta), norm.pdf(0)], "ko")

    fig.savefig(f"{name}/fig4.png")

    if eps is True:
        fig.savefig("manuscript/fig4.eps")
Example #8
    def test_save_and_load(self, tmpdir_factory):
        directory = str(tmpdir_factory.mktemp('data'))
        save_dir = pm.save_trace(self.trace, directory, overwrite=True)

        assert save_dir == directory

        trace2 = pm.load_trace(directory, model=TestSaveLoad.model())

        for var in ('x', 'z'):
            assert (self.trace[var] == trace2[var]).all()
Example #10
def load_final_trace(no_rd=False):
    if no_rd:
        filename_params = "../data/mcmc_samples_backup/parameters_covid19_final_no_rd"
    else:
        filename_params = "../data/mcmc_samples_backup/parameters_covid19_final"
    model = load_final_model(no_rd=no_rd)
    with model:
        trace = pm.load_trace(filename_params)
    del model
    return trace
Example #11
    def load(cls, input_file):
        saved_result = pd.read_pickle(input_file)
        trace_fname = saved_result.pop('trace', None)

        bhsm = cls(saved_result['galaxies'])
        if trace_fname is not None:
            with bhsm.model:
                trace = pm.load_trace(trace_fname)
        else:
            trace = None
        return {**saved_result, 'bhsm': bhsm, 'trace': trace}
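Hypothetical usage, assuming the enclosing class is named BHSM (the snippet only shows cls):

# class name and file name are assumptions for illustration
result = BHSM.load("saved_result.pickle")
bhsm, trace = result["bhsm"], result["trace"]  # trace is None if none was saved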
Example #12
def get_trace(path_trace, rhog, cs):
    path_results = '../../../../results/'
    data = pickle.load(open(path_results + 'data2ch.pickle', 'rb'))
    (dtslip, tiltsty, tiltstx, tiltsly, tiltslx, gps, stackx, stacky, tstack,
     xsh, ysh, dsh, xshErr, yshErr, dshErr, strsrc, strsrcErr) = data
    n = np.arange(1, len(tiltsty) + 1)
    tilt_std = 1e-5
    dt_std = 3600
    gps_std = 1
    with pm.Model() as model:
        gpsconst = pm.Uniform('gpsconst', lower=-15, upper=15)
        A_mod = pm.Uniform('A_mod', lower=0, upper=1e+2)
        B_mod = pm.Uniform('B_mod', lower=0, upper=1e+2)
        C_mod = pm.Uniform('C_mod', lower=0, upper=1e+2)
        D_mod = pm.Uniform('D_mod', lower=0, upper=1e+2)
        E_mod = pm.Uniform('E_mod', lower=0, upper=1e+2)
        F_mod = pm.Uniform('F_mod', lower=0, upper=1e+2)

        Vd_exp = pm.Uniform('Vd_exp', lower=8, upper=11)
        Vd_mod = pm.Deterministic('Vd_mod', 10**Vd_exp)

        kd_exp = pm.Uniform('kd_exp', lower=7, upper=10)
        kd_mod = pm.Deterministic('kd_mod', 10**kd_exp)
        R1 = pm.Deterministic('R1', rhog * Vd_mod / (kd_mod * S))
        ratio = pm.Uniform('ratio',
                           lower=30 * 4 * R1 / (1 + R1),
                           upper=100 * 4 * R1 / (1 + R1))
        pspd_mod = pm.Uniform('pspd_mod', lower=1e+5, upper=1e+7)
        ptps_mod = pm.Deterministic('ptps_mod', ratio * pspd_mod)
        conds_mod = pm.Uniform('conds_mod', lower=1, upper=10)

        deltap_mod = pm.Uniform('deltap', lower=1e+5, upper=ptps_mod)
        strsrc_mod = pm.Normal('strsrc_Mod', mu=strsrc, sigma=strsrcErr)
        Vs_mod = pm.Deterministic('Vs_mod', strsrc_mod / deltap_mod)

        #conds_mod = pm.Uniform('conds_mod',lower=1,upper=10)
        condd_mod = pm.Uniform('condd_mod', lower=1, upper=30)
        dsh_mod = pm.Normal('dsh_mod', mu=dsh, sigma=dshErr)
        xsh_mod = pm.Normal('xsh_mod', mu=xsh, sigma=xshErr)
        ysh_mod = pm.Normal('ysh_mod', mu=ysh, sigma=yshErr)
        coeffx = cs * dsh_mod * (
            x - xsh_mod) / (dsh_mod**2 + (x - xsh_mod)**2 +
                            (y - ysh_mod)**2)**(5. / 2) * Vd_mod
        coeffy = cs * dsh_mod * (
            y - ysh_mod) / (dsh_mod**2 + (x - xsh_mod)**2 +
                            (y - ysh_mod)**2)**(5. / 2) * Vd_mod
        tau2 = 8 * mu * ld * Vs_mod / (3.14 * condd_mod**4 * kd_mod)  # model set-up
        x_mod = gpsconst + 4 * R1 / rhog * pspd_mod / (1 + R1) * n

        pslip_mod = -rhog * x_mod
        pstick_mod = pslip_mod + 4 * pspd_mod / (1 + R1)

        trace2 = pm.load_trace(path_results + 'trace300000_strsrc_LF')
    return trace2
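Nothing is re-sampled here: the priors and deterministics are declared only so pm.load_trace can map the stored draws back onto named model variables. Schematically:

# illustrative skeleton of the pattern above
with pm.Model() as model:
    # ...same variable declarations the trace was sampled from...
    trace2 = pm.load_trace(path_results + 'trace300000_strsrc_LF')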
Example #13
    def test_save_and_load(self, tmpdir_factory):
        directory = str(tmpdir_factory.mktemp("data"))
        save_dir = pm.save_trace(self.trace, directory, overwrite=True)

        assert save_dir == directory

        trace2 = pm.load_trace(directory, model=TestSaveLoad.model())

        for var in ("x", "z"):
            assert (self.trace[var] == trace2[var]).all()

        assert self.trace.stat_names == trace2.stat_names
        for stat in self.trace.stat_names:
            assert all(self.trace[stat] == trace2[stat]), (
                "Restored value of statistic %s does not match stored value" %
                stat)
Example #14
    def test_save_new_model(self, tmpdir_factory):
        directory = str(tmpdir_factory.mktemp("data"))
        save_dir = pm.save_trace(self.trace, directory, overwrite=True)

        assert save_dir == directory
        with pm.Model() as model:
            w = pm.Normal("w", 0, 1)
            new_trace = pm.sample(return_inferencedata=False)

        with pytest.raises(OSError):
            _ = pm.save_trace(new_trace, directory)

        _ = pm.save_trace(new_trace, directory, overwrite=True)
        with model:
            new_trace_copy = pm.load_trace(directory)

        assert (new_trace["w"] == new_trace_copy["w"]).all()
Example #15
    def test_save_new_model(self, tmpdir_factory):
        directory = str(tmpdir_factory.mktemp('data'))
        save_dir = pm.save_trace(self.trace, directory, overwrite=True)

        assert save_dir == directory
        with pm.Model() as model:
            w = pm.Normal('w', 0, 1)
            new_trace = pm.sample()

        with pytest.raises(OSError):
            _ = pm.save_trace(new_trace, directory)

        _ = pm.save_trace(new_trace, directory, overwrite=True)
        with model:
            new_trace_copy = pm.load_trace(directory)

        assert (new_trace['w'] == new_trace_copy['w']).all()
Example #17
def load_trace(dir_path, bX, by):
    with pm.Model() as model:  # noqa
        length = pm.Gamma("length", alpha=2, beta=1)
        eta = pm.HalfCauchy("eta", beta=5)

        cov = eta**2 * pm.gp.cov.Matern52(input_dim=1, ls=length)
        gp = pm.gp.Latent(cov_func=cov)

        f = gp.prior("f", X=bX)

        sigma = pm.HalfCauchy("sigma", beta=5)
        nu = pm.Gamma("nu", alpha=2, beta=0.1)
        y_ = pm.StudentT("y", mu=f, lam=1.0 / sigma, nu=nu,
                         observed=by)  # noqa

        trace = pm.load_trace(dir_path)
        return model, gp, trace
Example #18
def table2(name, func, tex):
    """Makes table 2.

    Args:
        name (str): Descriptive name of the model. Posterior samples, statistics, and
            figures are generated and saved in a subdirectory with this name.
        func (callable): Function for model construction. Should
            return a formatted copy of the data.
        tex (bool): If True, saves the table to the manuscript subdirectory.

    """

    with pm.Model() as m:
        fit_model(name, func)
        trace = pm.load_trace(name)
        params = sorted([p.name for p in m.deterministics if "Lambda" in p.name])
        df = pm.summary(trace, var_names=params)

    table = []
    for p, i in zip(params, interps):

        theta = skewnorm.fit(trace[p])
        p0 = norm.pdf(0)
        p1 = skewnorm.pdf(0, *theta)
        bf = p0 / p1
        a, b, c = df.loc[p, ["mean", "hpd_2.5", "hpd_97.5"]]

        dic = {
            "Variable": p,
            "Posterior mean (95% HPD)": "%s (%s, %s)" % (
                latexify(a), latexify(b), latexify(c)),
            "During roved-frequency trials ...": i,
            "BF": latexify(bf),
            "Evidence": interpret(bf),
        }
        table.append(dic)
        # print(p, bf)

    df = pd.DataFrame(table)[dic.keys()]
    df.to_latex(f"{name}/table2.tex", escape=False, index=False)

    if tex is True:
        df.to_latex("manuscript/table2.tex", escape=False, index=False)
Example #19
    def test_sample_posterior_predictive(self, tmpdir_factory):
        directory = str(tmpdir_factory.mktemp("data"))
        save_dir = pm.save_trace(self.trace, directory, overwrite=True)

        assert save_dir == directory

        rng = np.random.RandomState(10)

        with TestSaveLoad.model(rng_seeder=rng):
            ppc = pm.sample_posterior_predictive(self.trace)

        rng = np.random.RandomState(10)

        with TestSaveLoad.model(rng_seeder=rng):
            trace2 = pm.load_trace(directory)
            ppc2 = pm.sample_posterior_predictive(trace2)

        for key, value in ppc.items():
            assert (value == ppc2[key]).all()
Example #20
    def test_sample_ppc(self, tmpdir_factory):
        directory = str(tmpdir_factory.mktemp('data'))
        save_dir = pm.save_trace(self.trace, directory, overwrite=True)

        assert save_dir == directory

        seed = 10
        np.random.seed(seed)
        with TestSaveLoad.model():
            ppc = pm.sample_ppc(self.trace)

        seed = 10
        np.random.seed(seed)
        with TestSaveLoad.model():
            trace2 = pm.load_trace(directory)
            ppc2 = pm.sample_ppc(trace2)

        for key, value in ppc.items():
            assert (value == ppc2[key]).all()
Example #22
def sampleandsave(f):
    """Sample from the model in context and save the results.

    If a saved trace already exists at `f`, load and return it instead.

    """
    if not exists(f):

        # sample and save

        trace = pm.sample(8000, tune=2000, chains=1)
        pm.save_trace(trace, f)
        pm.traceplot(trace, compact=True)
        rcParams["font.size"] = 14
        plt.savefig(f"{f}/traceplot.png")
        ppc = pm.sample_posterior_predictive(trace)["$Y$"]
        np.savez_compressed(f"{f}/ppc.npz", ppc)

    else:

        trace = pm.load_trace(f)

    return trace
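Hypothetical usage: f doubles as the trace directory and as the output folder for the figure and .npz files, and the function must be called inside a model context:

with pm.Model():
    build_model()  # assumed model-construction step
    trace = sampleandsave("results/model-a")  # illustrative path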
Example #23
import os

import arviz as az
import corner
import exoplanet as xo
import matplotlib.pyplot as plt
import pymc3 as pm

import src.close.rv.model as m
import src.data as d
from src.constants import *

plotdir = "figures/close/rv/"

if not os.path.isdir(plotdir):
    os.makedirs(plotdir)

trace = pm.load_trace(directory="chains/close/rv", model=m.model)

# view summary
df = az.summary(trace, var_names=m.all_vars)
print(df)

# write summary to disk
with open(f"{plotdir}summary.txt", "w") as f:
    df.to_string(f)

with az.rc_context(rc={"plot.max_subplots": 60}):
    # autocorrelation
    az.plot_autocorr(trace, var_names=m.sample_vars)
    plt.savefig(f"{plotdir}autocorr.png")
Example #24
    def test_bad_load(self, tmpdir_factory):
        directory = str(tmpdir_factory.mktemp("data"))
        with pytest.raises(pm.TraceDirectoryError):
            pm.load_trace(directory, model=TestSaveLoad.model())

Example #25

annInput = theano.shared(XsTrain)
annTarget = theano.shared(YsTrain)
errAnnInput = theano.shared(errXsTrain)
errAnnTarget = theano.shared(errYsTrain)

neural_network = construct_nn(annInput, errAnnInput, annTarget, errAnnTarget)

print("Starting the training of the BNN...")

if not os.path.exists(cache_file_bnn):

    with neural_network:
        #fit model
        trace = pm.sample(draws=nsamples,
                          init='advi+adapt_diag',
                          n_init=ninit,
                          tune=ninit // 2,
                          chains=nchains,
                          cores=ncores,
                          nuts_kwargs={'target_accept': 0.90},
                          discard_tuned_samples=True,
                          compute_convergence_checks=True,
                          progressbar=False)
    pm.save_trace(trace, directory=cache_file_bnn)
else:
    trace = pm.load_trace(cache_file_bnn, model=neural_network)

print("Done...")
Example #26
    phi = 1 * Vs_mod / Vd_mod
    #phi = ratio  * Vs_mod / kd_mod
    stack_mod = (A_mod * tt.exp(tstack / tau2 * (-T1 / 2 - phi / 2 + tt.sqrt(4 * phi + (-T1 + phi - 1)**2) / 2 - 1 / 2))
                 + B_mod * tt.exp(tstack / tau2 * (-T1 / 2 - phi / 2 - tt.sqrt(4 * phi + (-T1 + phi - 1)**2) / 2 - 1 / 2))
                 - E_mod)
#    stack_mod = A_mod * tt.exp(t*(-condd**4*pi*r/(16*ld*mu) + tt.sqrt((condd**4*pi*r/(8*ld*mu) - condd**4*ks*pi/(8*Vs*ld*mu) - conds**4*ks*pi/(8*Vs*ls*mu))**2 + condd**8*ks*pi**2*r/(16*Vs*ld**2*mu**2))/2 - condd**4*ks*pi/(16*Vs*ld*mu) - conds**4*ks*pi/(16*Vs*ls*mu))) +
#                B_mod * tt.exp
    # Posterior
    tslx_obs = pm.Normal('tslx_obs', mu=tslx_mod, sigma=tilt_std, observed=tiltslx)
    tsly_obs = pm.Normal('tsly_obs', mu=tsly_mod, sigma=tilt_std, observed=tiltsly)
    tstx_obs = pm.Normal('tstx_obs', mu=tstx_mod, sigma=tilt_std, observed=tiltstx)
    tsty_obs = pm.Normal('tsty_obs', mu=tsty_mod, sigma=tilt_std, observed=tiltsty)

    x_obs = pm.Normal('x_obs', mu=x_mod, sigma=gps_std, observed=gps)
    stack_obs = pm.Normal('stack_obs', mu=stack_mod, sigma=tilt_std * 1e+6, observed=stack)
    trace2 = pm.load_trace(path_results + 'trace300000_LF')
panda_trace = pm.backends.tracetab.trace_to_dataframe(trace2)
panda_trace['Vs_mod'] = np.log10(panda_trace['Vs_mod'])
panda_trace['kd_mod'] = np.log10(panda_trace['kd_mod'])
panda_trace['Vd_mod'] = np.log10(panda_trace['Vd_mod'])
panda_trace['pspd_mod'] = panda_trace['pspd_mod'] / 1e+6
filename = 'res300000_LF.pickle'
results = pickle.load(open(path_results + filename, 'rb'))

R1 = rho * g * results['MAP']['Vs_mod'] / (results['MAP']['kd_mod'] * S)
xMAP = results['MAP']['gpsconst'] + 4 * R1 / (rho * g) * results['MAP']['pspd_mod'] / (1 + R1) * n
pslipMAP = -rho * g * xMAP
pstickMAP = pslipMAP + 4 * results['MAP']['pspd_mod'] / (1 + R1)
T1 = (conds_mod / results['MAP']['condd_mod'])**4 * ld / ls
phi = 1 * results['MAP']['Vs_mod'] / results['MAP']['Vd_mod']
coeffx = cs * results['MAP']['dsh_mod'] * (x - results['MAP']['xsh_mod']) / (results['MAP']['dsh_mod']**2 + (x - results['MAP']['xsh_mod'])**2 + (y - results['MAP']['ysh_mod'])**2)**(5. / 2) * results['MAP']['Vs_mod']
Example #27

def forecast_main(clusters, cases_df, vel_cases_df, population_df,
                  cluster_mode, init_cluster_num, max_cluster_num,
                  initial_date, final_date, final_change_date, num_days_future,
                  dataset_final_date, run_mode, root_save_path):
    rmse_per_cluster_list = []
    total_re_per_cluster_list = []
    mean_rsquared_per_cluster_list = []
    mape_per_cluster_list = []
    wape_per_cluster_list = []

    mse_per_county_per_cluster_list = []
    re_per_county_per_cluster_list = []
    rsquared_per_county_per_cluster_list = []
    mape_per_county_per_cluster_list = []
    wape_per_county_per_cluster_list = []

    unclustered_rmse_per_cluster_list = []
    unclustered_mse_per_county_per_cluster_list = []
    unclustered_total_re_per_cluster_list = []
    unclustered_re_per_county_per_cluster_list = []

    initial_date = datetime.datetime.strptime(f'{initial_date}/2020',
                                              '%m/%d/%Y')
    final_date = datetime.datetime.strptime(f'{final_date}/2020', '%m/%d/%Y')
    dataset_final_date = datetime.datetime.strptime(
        f'{dataset_final_date}/2020', '%m/%d/%Y')

    cluster_colors = {
        1: rgb2hex(255, 0, 0),
        2: rgb2hex(255, 111, 0),
        3: rgb2hex(255, 234, 0),
        4: rgb2hex(151, 255, 0),
        5: rgb2hex(44, 255, 150),
        6: rgb2hex(0, 152, 255),
        7: rgb2hex(0, 25, 255)
    }

    # cluster id 0 was not clustered by Tzu Hsi but I still use it
    for cluster_id in range(init_cluster_num, max_cluster_num):
        if cluster_mode == 'unclustered':
            cluster_id = 'All'
            chosen_cluster_series = clusters
        else:
            chosen_cluster_series = clusters[clusters == cluster_id]
        cluster_counties = chosen_cluster_series.index.tolist()

        print('-----------------------------')
        print('Cluster ID: ', cluster_id)
        # ------------- Create save folders --------------
        cluster_save_path = root_save_path + f'/cluster_{cluster_id}/'
        if os.path.isdir(cluster_save_path) is False:
            os.mkdir(cluster_save_path)
        cluster_all_save_path = root_save_path + f'/cluster_All/'

        # -------------- Data Preprocessing --------------
        cluster_cases_df, proc_population_series = preprocess_dataset(
            cases_df.copy(), population_df.copy(), cluster_counties)
        cluster_vel_cases_df, _ = preprocess_dataset(vel_cases_df.copy(),
                                                     population_df.copy(),
                                                     cluster_counties)

        cluster_cases_df, current_cumulative_cases_df, future_cumulative_cases_df, old_cumulative_infected_cases_series, date_begin_sim, num_days_sim = \
            process_date(cluster_cases_df, initial_date, final_date, dataset_final_date, num_days_future)
        cluster_vel_cases_df, current_vel_cases_df, future_vel_cases_df, _, _, _ = \
            process_date(cluster_vel_cases_df, initial_date, final_date, dataset_final_date, num_days_future)

        current_cumulative_cases_series = current_cumulative_cases_df.sum(
            axis=1)
        current_vel_cases_series = current_vel_cases_df.sum(axis=1)
        cluster_total_population = proc_population_series.sum()
        future_cumulative_cases = future_cumulative_cases_df.sum(axis=1)[-1]

        print('old_cumulative_infected_cases_series:',
              old_cumulative_infected_cases_series)
        print('Cumulative future cases:', future_cumulative_cases)
        print('population:', cluster_total_population)
        print('Remaining population:',
              cluster_total_population - future_cumulative_cases)

        visualize_trend_with_r_not(cluster_id, cluster_vel_cases_df,
                                   cluster_save_path)

        # --------------- Get SIR Model -----------------
        # convert cumulative infected to daily total infected cases
        current_total_cases_series = (current_cumulative_cases_series -
                                      old_cumulative_infected_cases_series.sum())
        future_total_cases_df = future_cumulative_cases_df - old_cumulative_infected_cases_series

        day_1_cumulative_infected_cases = current_cumulative_cases_series[0]
        S_begin_beta = cluster_total_population - day_1_cumulative_infected_cases
        I_begin_beta = current_total_cases_series[0]  # day 1 total infected cases

        print('day_1_cumulative_infected_cases: ',
              day_1_cumulative_infected_cases)
        print('S_begin_beta: ', S_begin_beta)
        print('I_begin_beta: ', I_begin_beta)

        change_points = get_change_points(final_date, final_change_date,
                                          cluster_id)
        sir_model = Bayesian_Inference_SEIR.SIR_with_change_points(
            S_begin_beta,
            I_begin_beta,
            current_vel_cases_series.to_numpy(
            ),  # current_total_cases_series.to_numpy(),
            change_points_list=change_points,
            date_begin_simulation=date_begin_sim,
            num_days_sim=num_days_sim,
            diff_data_sim=0,
            N=cluster_total_population)

        # ---------- Estimate Parameters for SIR model ------------
        if run_mode == 'train':
            Bayesian_Inference_SEIR.run(sir_model,
                                        N_SAMPLES=10000,
                                        cluster_save_path=cluster_save_path)

        elif run_mode == 'eval':
            trace = pm.load_trace(cluster_save_path + 'sir_model.trace',
                                  model=sir_model)
            susceptible_series = proc_population_series - current_cumulative_cases_df.loc[
                final_date]

            # ---------- Forecast using unclustered data ------------------
            t = range(len(future_vel_cases_df.values))
            if cluster_mode == 'clustered':
                lambda_t, μ = np.load(cluster_all_save_path + 'SIR_params.npy',
                                      allow_pickle=True)
                beta, gamma = lambda_t[-1], μ[0]
                print('beta, gamma', beta, gamma)
                cluster_all_vel_case_forecast = sir_forecast_a_county(
                    susceptible_series.sum(),
                    moving_average_from_df(current_vel_cases_df).sum(),
                    cluster_total_population, t, beta, gamma, '', '')
            else:
                cluster_all_vel_case_forecast = None

            # ----------- Forecast using clustered data ------------------
            lambda_t, μ = np.load(cluster_save_path + 'SIR_params.npy',
                                  allow_pickle=True)
            beta, gamma = lambda_t[-1], μ[0]
            print('beta, gamma', beta, gamma)
            cluster_forecast_I0 = np.mean(
                trace['new_cases'][:, len(current_vel_cases_series)], axis=0)

            cluster_vel_case_forecast = sir_forecast_a_county(
                susceptible_series.sum(), cluster_forecast_I0,
                cluster_total_population, t, beta, gamma, '', '')

            # ----------- Forecast Visualization ---------------
            # reorder cluster id to 1,2,3,4,5,6,7 based on severity (rising trend)
            if cluster_id == 0:
                cluster_id = 7
            elif cluster_id == 2:
                cluster_id = 1
            elif cluster_id == 3:
                cluster_id = 3
            elif cluster_id == 4:
                cluster_id = 2
            elif cluster_id == 6:
                cluster_id = 5
            elif cluster_id == 7:
                cluster_id = 6
            elif cluster_id == 8:
                cluster_id = 4

            if cluster_id in cluster_colors.keys():
                cluster_color = cluster_colors[cluster_id]
            else:
                cluster_color = 'black'
            plot_cases(cluster_id,
                       cluster_color,
                       trace,
                       current_vel_cases_series,
                       future_vel_cases_df,
                       cluster_vel_case_forecast,
                       cluster_all_vel_case_forecast,
                       date_begin_sim,
                       diff_data_sim=0,
                       num_days_future=num_days_future,
                       cluster_save_path=cluster_save_path)

            # ---------- Evaluation per county -----------
            cluster_mse_dict, cluster_re_dict, cluster_rsquared_dict, cluster_mape_dict, cluster_wape_dict = \
                eval_per_cluster(susceptible_series, moving_average_from_df(current_vel_cases_df), future_vel_cases_df, proc_population_series, num_days_future,cluster_save_path)

            if cluster_mode == 'unclustered':
                for cluster_id in range(0, max_cluster_num):
                    local_mse_list = []
                    local_re_list = []
                    chosen_cluster_series = clusters[clusters == cluster_id]
                    cluster_counties = chosen_cluster_series.index.tolist()
                    for a_county in cluster_counties:
                        if a_county in cluster_mse_dict:
                            local_mse_list.append(cluster_mse_dict[a_county])
                        if a_county in cluster_re_dict:
                            local_re_list.append(cluster_re_dict[a_county])
                    unclustered_rmse_per_cluster_list.append(
                        math.sqrt(statistics.mean(local_mse_list)))
                    unclustered_mse_per_county_per_cluster_list.append(
                        local_mse_list)
                    unclustered_total_re_per_cluster_list.append(
                        statistics.mean(local_re_list))
                    unclustered_re_per_county_per_cluster_list.append(
                        local_re_list)

            elif cluster_mode == 'clustered':
                rmse_per_cluster_list.append(
                    math.sqrt(statistics.mean(cluster_mse_dict.values())))
                total_re_per_cluster_list.append(
                    statistics.mean(cluster_re_dict.values()))
                mean_rsquared_per_cluster_list.append(
                    statistics.mean(cluster_rsquared_dict.values()))
                # mape_per_cluster_list.append(statistics.mean(cluster_mape_dict.values()))
                wape_per_cluster_list.append(
                    statistics.mean(cluster_wape_dict.values()))

                mse_per_county_per_cluster_list.append(
                    list(cluster_mse_dict.values()))
                re_per_county_per_cluster_list.append(
                    list(cluster_re_dict.values()))
                rsquared_per_county_per_cluster_list.append(
                    list(cluster_rsquared_dict.values()))
                mape_per_county_per_cluster_list.append(
                    list(cluster_mape_dict.values()))
                wape_per_county_per_cluster_list.append(
                    list(cluster_wape_dict.values()))

        if cluster_mode == 'unclustered':
            break  # only run once for unclustered dataset

    return rmse_per_cluster_list, mse_per_county_per_cluster_list, mean_rsquared_per_cluster_list, rsquared_per_county_per_cluster_list, \
           mape_per_cluster_list, mape_per_county_per_cluster_list, wape_per_cluster_list, wape_per_county_per_cluster_list, \
           total_re_per_cluster_list, re_per_county_per_cluster_list, \
           unclustered_rmse_per_cluster_list, unclustered_mse_per_county_per_cluster_list, \
           unclustered_total_re_per_cluster_list, unclustered_re_per_county_per_cluster_list
Example #28
def get_SIR(x, y, y0, country, forecast_len=0, load_post=False):
    '''
    If 'forecast_len' is nonzero, attempts to load a trace corresponding to the
    country of interest from the directory 'traces' and retrieves predicted numbers
    of infected and susceptible patients 'forecast_len' days into the future after the 
    1st case is detected in the country.
    '''

    # If in 'prediction mode', modify x, y to reflect forecast length
    if forecast_len != 0:
        ext = np.arange(1, forecast_len + 1).astype(float)
        ext += x[-1]
        x = np.append(x, ext)
        y = np.empty((x.shape[0], y.shape[1]))

    # SIR Model
    # p[0]: beta, p[1]: lambda
    def SIR(y, t, p):
        ds = -p[0] * y[0] * y[1]  # Susceptible differential
        di = p[0] * y[0] * y[1] - p[1] * y[1]  # Infected differential
        return [ds, di]

    # Initialize ODE
    sir_ode = DifferentialEquation(func=SIR,
                                   times=x,
                                   n_states=2,
                                   n_theta=2,
                                   t0=0)

    load_dir = osp.join('traces', country.lower())

    with pm.Model() as model:
        sigma = pm.HalfNormal('sigma', 3, shape=2)

        # R0 is bounded below by 1 because we see an epidemic has occurred
        R0 = pm.Normal('R0', 2, 3)

        lmbda = pm.Normal('lambda', 0.1, 0.1)

        beta = pm.Deterministic('beta', lmbda * R0)

        print('Setting up model for ' + country)
        sir_curves = sir_ode(y0=y0, theta=[beta, lmbda])

        y_obs = pm.Normal('y_obs', mu=sir_curves, sigma=sigma, observed=y)

        if forecast_len == 0:
            trace = pm.sample(2000,
                              tune=1000,
                              cores=2,
                              chains=2,
                              progressbar=True)

            # Save trace
            pm.save_trace(trace, load_dir, overwrite=True)

            # Get the posterior
            post = pm.sample_posterior_predictive(trace, progressbar=True)

            out_post = post
        else:
            # Load trace
            print('Loading trace')
            trace = pm.load_trace(load_dir)

            print('Computing posterior')
            # Get posterior
            if not load_post:
                post = pm.sample_posterior_predictive(trace[500:],
                                                      progressbar=True)
                out_post = post
                with open(country + '_post.pkl', 'wb') as buff:
                    pickle.dump({'post': post}, buff)

            else:
                with open(country + '_post.pkl', 'rb') as buff:
                    data = pickle.load(buff)
                out_post = data['post']

    print('Done')

    return trace, out_post, x
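A hypothetical forecasting call, following the signature above: x holds observation times in days, y the observed susceptible/infected pairs, and y0 the initial state; the country name selects the trace directory (here traces/italy):

# data arrays and country name are illustrative
trace, post, x_ext = get_SIR(x, y, y0, "Italy", forecast_len=30)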
Example #29
def main():

    # load the data
    df = pd.read_csv("../../assets/data/HS.csv", index_col=0)

    # define items to keep
    item_names = [
        "visual",
        "cubes",
        "paper",
        "flags",
        "general",
        "paragrap",
        "sentence",
        "wordc",
        "wordm",
        "addition",
        "code",
        "counting",
        "straight",
        "wordr",
        "numberr",
        "figurer",
        "object",
        "numberf",
        "figurew",
    ]

    # define the factor structure
    factors = np.array([
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
    ])

    paths = np.array([
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0],
    ])

    # iterate over the two schools
    for school, sdf in df.groupby("school"):

        # define the path to save results
        f = f"../data/BSEM examples/{school}"

        # select the 19 commonly used variables
        items = sdf[item_names]

        # for numerical convenience, standardize the data
        items = (items - items.mean()) / items.std()

        with pm.Model():

            # construct the model
            bsem(items, factors, paths)

            if not exists(f):

                # sample and save
                trace = pm.sample(chains=2)  # 19000, tune=1000,
                pm.save_trace(trace, f)

            else:

                trace = pm.load_trace(f)

        pm.traceplot(trace, compact=True)
        rcParams["font.size"] = 14
        plt.savefig(f"{f}/traceplot.png")

        # create a nice summary table
        loadings = pd.DataFrame(
            trace[r"$\Lambda$"].mean(axis=0).round(3),
            index=[v.title() for v in item_names],
            columns=["Spatial", "Verbal", "Speed", "Memory", "g"],
        )
        loadings.to_csv(f"{f}/loadings.csv")
        print(tabulate(loadings, tablefmt="pipe", headers="keys"))
        #
        # # correlations = pd.DataFrame(
        # #     trace[r"$\Psi$"].mean(axis=0).round(3),
        # #     index=["Spatial", "Verbal", "Speed", "Memory", "g"],
        # #     columns=["Spatial", "Verbal", "Speed", "Memory", "g"],
        # # )
        # # correlations.to_csv(f"{f}/factor_correlations.csv")
        #
        _paths = pd.DataFrame(
            trace[r"$\Gamma$"].mean(axis=0).round(3),
            index=["Spatial", "Verbal", "Speed", "Memory", "g"],
            columns=["Spatial", "Verbal", "Speed", "Memory", "g"],
        )
        _paths.to_csv(f"{f}/factor_paths.csv")
        print(tabulate(_paths, tablefmt="pipe", headers="keys"))
Example #30
def show_plot():
    plt.close('all')
    plt.ioff()

    # load in the trace
    from glob import glob
    import os
    global df, trace_o, trace_n, title
    df = pd.read_csv('merged_data.csv')

    # read in the models
    fit_o = models.oocyte_model(df)
    fit_n = models.oocyte_model(df)

    # read in the traces
    trace_o = pm.load_trace('trace_o', fit_o)
    trace_n = pm.load_trace('trace_n', fit_n)

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    out = widgets.Output()

    #res_df = pd.read_pickle("res.pkl")
    def click(b):
        global o_obs, o_x, n_obs, n_x, df

        mask = df['i_ind'] == choice.value
        if (sum(mask) > 0):
            title_comp = df.loc[mask, 'ucode'].values[0].split('_')
            fig.suptitle("{}".format(title_comp[0]))
            ax.set_title("day {}.rep {}, str {}".format(*title_comp[1:]))

        Vmax_o = Vmax_o_slide.value
        ao = ao_slide.value
        to = to_slide.value
        Vmin_o = Vmin_o_slide.value

        Vmax_n = Vmax_n_slide.value
        a0 = a0_slide.value
        t0 = t0_slide.value
        a1 = a1_slide.value
        t1 = t1_slide.value
        Vmin_n = Vmin_n_slide.value
        #plt.figure(2)

        ax.clear()
        ax.plot(-o_x, o_obs, 'o', color='red')
        ax.plot(-n_x, n_obs, 'o', color='blue')
        if (True):
            x = np.linspace(-2, 16, num=100)
            ax.plot(-x,
                    bu.rise_only(x, Vmax_o, to, ao, 1.0, Vmin_o),
                    color='red')
            ax.plot(-x,
                    bu.rise_and_fall(x, Vmax_n, t0, a0, 1.0, t1, a1, 1,
                                     Vmin_n),
                    color='blue')

            ax.plot([-to, -to],
                    [0, bu.rise_only(to, Vmax_o, to, ao, 1.0, Vmin_o)],
                    ls='--',
                    color='red',
                    lw=1.5)
            ax.plot([-t0, -t0], [
                0,
                bu.rise_and_fall(t0, Vmax_n, t0, a0, 1.0, t1, a1, 1, Vmin_n)
            ],
                    ls='--',
                    color='blue',
                    lw=1.5)
            ax.plot([-t1, -t1], [
                0,
                bu.rise_and_fall(t1, Vmax_n, t0, a0, 1.0, t1, a1, 1, Vmin_n)
            ],
                    ls='--',
                    color='blue',
                    lw=1.5)

            ax.set_xticks(range(-14, 2, 2))
            ax.set_ylim(0, 200)
            ax.axhline(Vmax_o, ls='--', color='red')
            ax.axvline(0, ls='--', color='grey', lw=1)

            [xmin, xmax] = ax.get_xlim()
            [ymin, ymax] = ax.get_ylim()
            ax.annotate(r'$t_o$', (to + 0.2, ymax - 10))
            ax.annotate(r'$Vmax_{o}$', (xmax - 2, Vmax_o + 10))

        with out:
            clear_output(wait=True)
            display(ax.figure)

    #choice=Dropdown(
    #
    #	options='cont_r1 cont_r2 cont_r3 caged_d04_r1 caged_d07_r1'.split(),
    #	value='cont_r1',
    #	description='Number:',
    #	disabled=False,
    #)
    choice = widgets.IntText(value=0,
                             min=0,
                             max=len(df['i_ind'].unique()),
                             step=1,
                             description='Test:',
                             disabled=False,
                             continuous_update=False,
                             readout=True,
                             readout_format='d')

    Vmax_o_slide = FloatSlider(description=r'$V$max$_o$',
                               value=150,
                               min=0,
                               max=300,
                               continuous_update=False)
    Vmax_o_slide.observe(click, names='value')
    Vmin_o_slide = FloatSlider(description=r'$V$min$_o$',
                               value=15,
                               min=0,
                               max=30,
                               continuous_update=False)
    Vmin_o_slide.observe(click, names='value')
    ao_slide = FloatSlider(description=r'$a_o$',
                           value=0.2,
                           min=0.,
                           max=0.75,
                           continuous_update=False)
    ao_slide.observe(click, names='value')
    to_slide = FloatSlider(description=r'$t_o$',
                           value=1,
                           min=-2,
                           max=6,
                           continuous_update=False)
    to_slide.observe(click, names='value')

    Vmax_n_slide = FloatSlider(description=r'$V$max$_{n}$',
                               value=150,
                               min=0,
                               max=300,
                               continuous_update=False)
    Vmax_n_slide.observe(click, names='value')
    Vmin_n_slide = FloatSlider(description=r'$V$min$_n$',
                               value=15,
                               min=0,
                               max=30,
                               continuous_update=False)
    Vmin_n_slide.observe(click, names='value')
    a0_slide = FloatSlider(description=r'$a_0$',
                           value=0.4,
                           min=0.0,
                           max=1.5,
                           continuous_update=False)
    a0_slide.observe(click, names='value')
    t0_slide = FloatSlider(description=r'$t_0$',
                           value=0,
                           min=-4,
                           max=8,
                           continuous_update=False)
    t0_slide.observe(click, names='value')

    a1_slide = FloatSlider(description=r'$a_1$',
                           value=0.4,
                           min=0.0,
                           max=8,
                           continuous_update=False)
    a1_slide.observe(click, names='value')
    t1_slide = FloatSlider(description=r'$t_1$',
                           value=0.5,
                           min=-2,
                           max=6,
                           continuous_update=False)
    t1_slide.observe(click, names='value')

    def choice_selected(b):
        global o_obs, o_x, n_obs, n_x, df, trace_o, trace_n
        if (False):
            name = choice.value
            df = pd.read_csv("results_analyse/{}.csv".format(name))
            o_obs = dh.unpack_results(df, 1, 'o', volume=False)
            o_x = np.arange(len(o_obs))
            n_obs = dh.unpack_results(df, 1, 'n', volume=False)
            n_x = np.arange(len(n_obs))
        else:
            iexp = choice.value
            mask = df['i_ind'] == iexp
            if (sum(mask) > 0):
                o_obs = df.loc[mask, 'Oc_size']
                o_x = df.loc[mask, 'pos']
                n_obs = df.loc[mask, 'Ns_size']
                n_x = o_x

                vars_o = 'Vmax_o,t_o,a_o,Vmin_o'.split(',')
                vars_n = 'Vmax_n t0 a0 t1 a1 Vmin_n'.split()
                theta_o = np.median(bu.pull_post(trace_o, vars_o, iexp),
                                    axis=0)
                theta_n = np.median(bu.pull_post(trace_n, vars_n, iexp),
                                    axis=0)
                for slide, val in zip(
                    [Vmax_o_slide, to_slide, ao_slide, Vmin_o_slide], theta_o):
                    slide.value = val
                for slide, val in zip([
                        Vmax_n_slide, t0_slide, a0_slide, t1_slide, a1_slide,
                        Vmin_n_slide
                ], theta_n):
                    slide.value = val

        #rown = "{}_1".format(name)
        #Vmax_n_slide.value= res_df.loc[rown,('Vmax_n','m')]
        #a0_slide.value= res_df.loc[rown,('a0','m')]
        #t0_slide.value= -res_df.loc[rown,('t0','m')]
        #a1_slide.value= res_df.loc[rown,('a1','m')]
        #t1_slide.value= -res_df.loc[rown,('t1','m')]

        #Vmax_o_slide.value= res_df.loc[rown,('Vmax_o','m')]
        #ao_slide.value= res_df.loc[rown,('a_o','m')]
        #to_slide.value= -res_df.loc[rown,('t_o','m')]
        click(None)
        #f(Vmax_slide.value, a0_slide.value, t0_slide.value)
        return

    choice_selected(None)
    choice.observe(choice_selected)

    #display(VBox([mslide,cslide]))
    oocyte_params = widgets.VBox([
        Label(value="Oocyte"), Vmax_o_slide, ao_slide, to_slide, Vmin_o_slide
    ])
    nurse_params = widgets.VBox(
        [Vmax_n_slide, a0_slide, t0_slide, a1_slide, t1_slide, Vmin_n_slide])
    box = widgets.VBox(
        [choice, widgets.HBox([oocyte_params, nurse_params]), out])
    display(box)

    click(None)
Example #31
if not os.path.exists(cache_file_hier):

    with neural_network:
        #fit model
        trace_hier = pm.sample(draws=nsamples_hier,
                               init='advi+adapt_diag',
                               n_init=ninit,
                               tune=ninit // 2,
                               chains=nchains_hier,
                               cores=ncores_hier,
                               nuts_kwargs={'target_accept': 0.90},
                               discard_tuned_samples=True,
                               compute_convergence_checks=True,
                               progressbar=False)
    pm.save_trace(trace_hier, directory=cache_file_hier)
else:
    trace_hier = pm.load_trace(cache_file_hier, model=neural_network)

print("Done...")

if not os.path.exists(cache_file_samples):

    samples_tmp = defaultdict(list)
    samples = {}

    for layer_name in layer_names:
        for mu, sd in zip(
                trace_hier.get_values(layer_name,
                                      burn=nsamples_hier // 2,
                                      combine=True),
                trace_hier.get_values(layer_name + '_sd',
                                      burn=nsamples_hier // 2,
Example #32
import os
from pathlib import Path

import arviz as az
import pymc3 as pm

import model as m
import twa.data as d
from twa import wide
from twa.constants import *
from twa.plot_utils import plot_cline, plot_nodes
from twa.plot_utils import efficient_autocorr, efficient_trace

plotdir = Path("figures")
diagnostics = True

if not os.path.isdir(plotdir):
    os.makedirs(plotdir)

trace = pm.load_trace(directory="chains", model=m.model)

if diagnostics:
    ar_data = az.from_pymc3(trace=trace)

    # view summary
    df = az.summary(ar_data, var_names=m.all_vars)

    # write summary to disk
    with open(plotdir / "summary.txt", "w") as f:
        df.to_string(f)

    with az.rc_context(rc={"plot.max_subplots": 80}):
        stem = str(plotdir / "autocorr{:}.png")
        efficient_autocorr(ar_data, var_names=m.all_vars, figstem=stem)
Example #33
    best_model = pkl.load(f)

for disease in diseases:
    use_age = best_model[disease]["use_age"]
    use_eastwest = best_model[disease]["use_eastwest"]
    prediction_region = "bavaria" if disease == "borreliosis" else "germany"
    filename_params = "../data/mcmc_samples/parameters_{}_{}_{}".format(
        disease, use_age, use_eastwest)
    filename_model = "../data/mcmc_samples/model_{}_{}_{}.pkl".format(
        disease, use_age, use_eastwest)

    with open(filename_model, "rb") as f:
        model = pkl.load(f)

    with model:
        trace = pm.load_trace(filename_params)

    fig = plt.figure(figsize=(12, 14))
    grid = GridSpec(1,
                    2,
                    top=0.9,
                    bottom=0.1,
                    left=0.07,
                    right=0.97,
                    hspace=0.25,
                    wspace=0.15)

    W_ia_args = {
        "W_ia": {
            "color": "C1",
            "label": "$W_{IA}$",