예제 #1
0
def run_stan_model(stan_file, data, **kwargs):
    """
    Build a CmdStanModel from a Stan file, draw samples and run diagnostics.

    Parameters
    ----------
    stan_file
        Location of the Stan program; handed to
        `CmdStanModel(stan_file=...)`.
    data
        Model data; forwarded as `sample(data=...)`.
    **kwargs
        Any additional keyword arguments, passed unchanged to the
        model's `sample` method.

    Returns
    -------
    tuple
        `(model, fit)`: the model object and the object returned by
        its `sample` call.

    Notes
    -----
    To simulate data (for example for prior predictive sampling),
    pass `fixed_param=True`.
    https://cmdstanpy.readthedocs.io/en/latest/sample.html#example-generate-data-fixed-param-true

    Whatever `fit.diagnose()` returns is not captured by this function.
    """
    stan_model = CmdStanModel(stan_file=stan_file)
    stan_model.compile()

    mcmc_fit = stan_model.sample(data=data, **kwargs)
    mcmc_fit.diagnose()

    return stan_model, mcmc_fit
예제 #2
0
def run_bernoulli_fit():
    """
    Fit the Bernoulli example shipped with CmdStan and print its summary.

    The Stan program is looked up under
    `<cmdstan_path()>/examples/bernoulli/bernoulli.stan`, sampled with
    4 chains (`cores=2`, progress display enabled) on a fixed data set,
    and the result of `summary()` is printed. Nothing is returned.
    """
    # Data dictionary handed to the sampler.
    observations = {'N': 10, 'y': [0, 1, 0, 0, 0, 0, 0, 0, 0, 1]}

    # Locate the example program inside the CmdStan installation,
    # then create and compile the model object.
    example_dir = os.path.join(cmdstan_path(), 'examples', 'bernoulli')
    stan_source = os.path.join(example_dir, 'bernoulli.stan')
    model = CmdStanModel(stan_file=stan_source)
    model.compile()

    # Run the sampler, showing progress while it works.
    fit = model.sample(
        chains=4,
        cores=2,
        data=observations,
        show_progress=True,
    )

    # Summarize the results (wraps CmdStan `bin/stansummary`).
    summary = fit.summary()
    print(summary)
예제 #3
0

# Fit two Stan programs (stan/logist.stan and stan/mechai.stan) to the same
# slice of putting data and hand both fits to runGolf with a shared axis.
# J, Start, x_distance, y_successes_shrunk, n_attempts_shrunk and runGolf
# are defined outside this snippet.

# Data passed to both models: the Start:J slice of each input array,
# converted with .tolist(). "J" is J - Start, which equals the slice length
# assuming 0 <= Start <= J <= len(array) -- TODO confirm against the caller.
json_data = {
    "J": J - Start,
    "x_distance": x_distance[Start:J].tolist(),
    "y_successes": y_successes_shrunk[Start:J].tolist(),
    "n_attempts": n_attempts_shrunk[Start:J].tolist()
}

# One row, two columns: the first axis (ax2) receives the histograms of the
# draws, the second (ax) is the one passed to runGolf for both fits.
fig, (ax2, ax) = plt.subplots(1, 2)
ax.set_facecolor('grey')
ax.set_ylabel('chance in 1')
ax.set_xlabel('putt distance (feet)')

# ---- first model: stan/logist.stan ---- #
stan_program = CmdStanModel(stan_file='stan/logist.stan')
stan_program.compile()
# csv_basename is forwarded to sample(); presumably the prefix of the output
# CSV files -- confirm for the installed cmdstanpy version.
fit = stan_program.sample(data=json_data, csv_basename='./puttbetlog')
# print(fit.summary())  # debugging aid, left disabled
logistic_color = 'y'
# runGolf is defined elsewhere; presumably it draws this fit on `ax` in the
# given color -- TODO confirm.
runGolf(fit, 'stan/logist.stan', ax, logistic_color)
# Pull the draws of the two named parameters and histogram them on ax2.
# Only the a_intercept histogram reuses logistic_color; b_slope gets
# matplotlib's default color.
a_draws = fit.get_drawset(['a_intercept']).to_numpy()
b_draws = fit.get_drawset(['b_slope']).to_numpy()
ax2.hist(a_draws, color=logistic_color, label='a_intercept')
ax2.hist(b_draws, label='b_slope')
ax2.legend(loc='upper right')
ax2.set_ylabel('number of draws')

# ---- second model: stan/mechai.stan ---- #
# stan_program is rebound to the second model; the first fit stays in `fit`
# and the second result is stored separately in `fit2`.
stan_program = CmdStanModel(stan_file='stan/mechai.stan')
stan_program.compile()
fit2 = stan_program.sample(data=json_data, csv_basename='./puttbetmech')
runGolf(fit2, 'stan/mechai.stan', ax, 'g')
예제 #4
0
from cmdstanpy import CmdStanModel

from prep import build_stan_data

# Build the data for the Stan model with the project helper from `prep`.
# The meaning of the "050401" argument is defined by build_stan_data and is
# not visible here.
data = build_stan_data("050401")

# Create the model from simple_xrt.stan under the name "xrt" and call
# compile() with force=True.
model = CmdStanModel(model_name="xrt", stan_file="simple_xrt.stan")
model.compile(force=True)

# Sample a single chain with a fixed seed, 200 warmup iterations and
# 200 sampling iterations, with progress display enabled.
fit = model.sample(
    data=data,
    chains=1,
    seed=1234,
    iter_warmup=200,
    iter_sampling=200,
    show_progress=True,
)
예제 #5
0

# Directories used by this example. Only model_dir is referenced in the
# statements below; data_dir is defined but not used by them.
model_dir = Path("stan_code")
data_dir = Path("data")

# ---- data ---- #
# Data dictionary for the model: "J" is 8 and both "y" and "sigma" are
# NumPy arrays holding eight floats each.
eight_school_data = {
    "J": 8,
    "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
    "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
}

# ---- model ---- #
# Build the model from stan_code/schools.stan and compile it.
stan_file = model_dir / "schools.stan"
stan_model = CmdStanModel(stan_file=stan_file)
stan_model.compile()

# ---- fitting ---- #
# Sample with the data above; no other sampler arguments are given, so the
# library defaults apply.
stan_fit = stan_model.sample(data=eight_school_data)

# ---- results ---- #
cmdstanpy_data = az.from_cmdstanpy(
    posterior=stan_fit,
    posterior_predictive="y_hat",
    observed_data={"y": eight_school_data["y"]},
    log_likelihood="log_lik",
    coords={"school": np.arange(eight_school_data["J"])},
    dims={
        "theta": ["school"],
        "y": ["school"],
        "log_lik": ["school"],
예제 #6
0
import matplotlib.pyplot as plt
from typing import Dict
import arviz as az

# Directories for this example; neither is referenced by the statements
# below.
model_dir = Path("stan_code")
data_dir = Path("data")

# ---- data ---- #
# Data dictionary passed to the sampler, plus a DataFrame holding the same
# "y" values indexed 0..N-1. The DataFrame is not used by the statements
# below.
bernoulli_data: Dict = {"N": 10, "y": [0, 1, 0, 0, 0, 0, 0, 0, 0, 1]}
data = pd.DataFrame({"y": bernoulli_data["y"]},
                    index=np.arange(bernoulli_data["N"]))

# ---- model ---- #
# Use the Bernoulli example program found under the CmdStan installation
# directory returned by cmdstan_path().
bernoulli_stan = Path(cmdstan_path()) / "examples/bernoulli/bernoulli.stan"
bernoulli_model = CmdStanModel(stan_file=bernoulli_stan)
bernoulli_model.compile()

# ---- fitting ---- #
# Four chains with a fixed seed and progress display enabled. cores=1
# presumably restricts sampling to one process at a time -- confirm for the
# installed cmdstanpy version.
bern_fit: CmdStanMCMC = bernoulli_model.sample(
    data=bernoulli_data,
    chains=4,
    cores=1,
    seed=1111,
    show_progress=True,
)

# ---- results ---- #
# NOTE: the bare triple-quoted string below is an expression statement with
# no effect at runtime; it serves as a block comment describing the layout
# of the samples.
"""samples = multi-dimensional array
    all draws from all chains arranged as dimensions:
    (draws, chains, columns).
"""