Exemple #1
0
    # Synthetic observations from several data sources, each generated with
    # its own true noise standard deviation (one entry per source).
    true_std_devs = (20, 5, 0.25)
    noisy_data = gen_data_from_multi_src(eval_model, true_std_devs, plot=True)
    # Derive every per-source size from this one count so the likelihood
    # arguments, priors and labels cannot drift out of sync with the data.
    num_sources = len(true_std_devs)

    num_samples = 50000
    burnin = num_samples // 2  # burn-in: first half of the samples
    # Two model parameters ('a', 'b') plus one unknown std dev per source.
    num_params = 2 + num_sources
    # Two starting points (presumably one per chain — confirm against
    # VectorMCMC.metropolis): all ones and all twos.
    init_inputs = np.array([[1] * num_params, [2] * num_params])
    cov = np.eye(num_params)

    src_num_pts = (NUM_DATA_PTS,) * num_sources
    src_std_devs = (None,) * num_sources  # None: estimate every source std dev
    log_like_args = [src_num_pts, src_std_devs]
    log_like_func = MultiSourceNormal

    # One prior per source std dev, each a distinct instance; the former
    # ``[ImproperUniform(0, None)] * 3`` put the same object in the list
    # three times.
    priors = [ImproperUniform(0., 6.), ImproperUniform(0., 6.)] + \
             [ImproperUniform(0, None) for _ in range(num_sources)]
    vector_mcmc = VectorMCMC(eval_model, noisy_data, priors, log_like_args,
                             log_like_func)

    chain = vector_mcmc.metropolis(init_inputs,
                                   num_samples,
                                   cov,
                                   adapt_interval=200,
                                   adapt_delay=5000,
                                   progress_bar=True)

    # Labels: 'a', 'b', then 'std1' .. 'stdN' for the N source std devs.
    param_labels = ['a', 'b'] + \
                   ['std{}'.format(i + 1) for i in range(num_sources)]
    plot_mcmc_chain(chain,
                    param_labels=param_labels,
                    burnin=burnin,
                    include_kde=True)
Exemple #2
0
def constant_mean_model(inputs, n_data_pts):
    """Predict a constant signal equal to each input row's first parameter.

    Parameters
    ----------
    inputs : numpy.ndarray
        2-D array with one parameter vector per row; column 0 is the mean.
    n_data_pts : int
        Number of predicted data points per row.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(inputs.shape[0], n_data_pts)`` whose row ``i`` is
        ``inputs[i, 0]`` repeated ``n_data_pts`` times.
    """
    # Index with [0] (not 0) to keep a (n_rows, 1) column. Tiling the 1-D
    # slice instead promotes it to (1, n_rows) and lays every row's mean
    # end-to-end in a single row of length n_rows * n_data_pts.
    return np.tile(inputs[:, [0]], (1, n_data_pts))


if __name__ == "__main__":

    np.random.seed(2)

    true_mean = 1
    true_std = 2
    n_data_pts = 100
    n_samples = 6000
    param_names = ['mean', 'std']

    def model(x):
        """Constant-mean model evaluated for every parameter row in ``x``."""
        return constant_mean_model(x, n_data_pts)

    # Observations drawn from the true distribution being inferred.
    data = np.random.normal(true_mean, true_std, n_data_pts)
    priors = [ImproperUniform(), ImproperUniform(0, None)]
    # Start far from the true values (mean=1, std=2).
    x0 = np.array([[-3, 7]])
    cov = np.eye(2) * 0.001

    vmcmc = VectorMCMC(model, data, priors)
    chain = vmcmc.metropolis(x0,
                             n_samples,
                             cov,
                             adapt_interval=100,
                             adapt_delay=1000,
                             progress_bar=True)

    # Geweke diagnostic on chain[0] (presumably the first chain — confirm
    # against VectorMCMC.metropolis); its burn-in candidates and z-scores
    # are plotted below.
    burnin, z = compute_geweke(chain[0], window_pct=10, step_pct=1)

    plot_mcmc_chain(chain, param_names)
    plot_geweke(burnin, z, param_names)