def test_merge_traces_diff_lengths(self):
    """merge_traces must reject chains that have different lengths.

    Two single-chain MultiTraces are built here on distinct chain ids
    (0 and 1), holding ``self.draws`` and ``2 * self.draws`` records
    respectively, so the length mismatch is the only thing wrong with the
    pair and the expected ValueError cannot come from a chain-id clash.
    """
    with self.model:
        # Chain 0: must differ from strace1's chain id (1); a shared id
        # would make merge_traces fail for non-unique chains instead,
        # masking the length check this test is about.
        strace0 = self.backend(self.name)
        strace0.setup(self.draws, 0)
        for _ in range(self.draws):
            strace0.record(self.test_point)
        strace0.close()
    mtrace0 = base.MultiTrace([strace0])

    with self.model:
        # Chain 1 with twice as many records as chain 0.
        strace1 = self.backend(self.name)
        strace1.setup(2 * self.draws, 1)
        for _ in range(2 * self.draws):
            strace1.record(self.test_point)
        strace1.close()
    mtrace1 = base.MultiTrace([strace1])

    with pytest.raises(ValueError):
        base.merge_traces([mtrace0, mtrace1])
def sample_model(model, step=None, num_samples=MAX_NUM_SAMPLES, advi=False,
                 n_chains=NUM_CHAINS, raw_trace=False, single_chain=True,
                 num_scale1_iters=NUM_SCALE1_ITERS,
                 num_scale0_iters=NUM_SCALE0_ITERS):
    """Sample one or more chains from a constructed Bayesian model.

    Every chain is drawn through ``sample_chain`` with the shared sampler
    settings (``step``, ``num_samples``, ``advi`` and the two scale
    iteration counts).

    When ``advi`` or ``single_chain`` is true, a single ``sample_chain``
    call is made and single-chain diagnostics are computed. Otherwise,
    ``n_chains`` chains are drawn one after another and merged with
    ``merge_traces``. Their diagnostics are computed on copies truncated
    to a common length via ``merge_truncated_traces``.

    Returns
    -------
    tuple
        ``(trace, diagnostics)``. ``trace`` is returned raw when
        ``raw_trace`` is true, otherwise as
        ``format_trace(trace, to_df=True)``.
    """
    run_chain = partial(
        sample_chain,
        step=step,
        num_samples=num_samples,
        advi=advi,
        num_scale1_iters=num_scale1_iters,
        num_scale0_iters=num_scale0_iters,
    )

    if advi or single_chain:
        # ADVI takes the same single-call path as an explicit single chain.
        trace = run_chain(model)
        diagnostics = get_diagnostics(trace, model, single_chain=True)
    else:
        chains = []
        for idx in range(n_chains):
            print('chain {} of {}'.format(idx + 1, n_chains))
            chains.append(run_chain(model, chain_i=idx))
        # merge_traces modifies the first trace it receives, so keep an
        # untouched copy of chain 0 for the per-chain diagnostics.
        first_chain = deepcopy(chains[0])
        trace = merge_traces(chains)
        diagnostics = get_diagnostics(
            merge_truncated_traces([first_chain] + chains[1:]),
            model,
            single_chain=False,
        )

    result = trace if raw_trace else format_trace(trace, to_df=True)
    return result, diagnostics
def main():
    """Fit a Bayesian simple linear regression with NUTS and persist it.

    Draws ``X, Y`` from ``generate_sample()`` and builds the model
    ``y ~ Normal(beta * X + alpha, sigma)``. NUTS is started at the MAP
    estimate and samples ``chainnum`` chains of ``itenum`` draws. The
    function then prints the Gelman-Rubin statistic and the elapsed time,
    draws trace and forest plots, and pickles the model and trace to disk
    before reading both back.

    NOTE(review): ``multicore``, ``itenum``, ``chainnum``, ``progress``,
    ``saveimage`` and ``t0`` are not defined in this function; presumably
    they are module-level settings — confirm.
    """
    X, Y = generate_sample()
    with pm.Model() as model:
        alpha = pm.Normal('alpha', mu=0, sd=20)
        beta = pm.Normal('beta', mu=0, sd=20)
        # NOTE(review): no ``upper`` is passed, so pm.Uniform's default
        # upper bound applies to sigma — confirm that is intended.
        sigma = pm.Uniform('sigma', lower=0)
        y = pm.Normal('y', mu=beta*X+alpha, sd=sigma, observed=Y)
        start = pm.find_MAP()
        step = pm.NUTS(state=start)
    with model:
        if multicore:
            # Parallel chains, one random seed per chain.
            trace = pm.sample(itenum, step, start=start, njobs=chainnum,
                              random_seed=range(chainnum),
                              progressbar=progress)
        else:
            # Sequential chains, combined into one trace afterwards.
            ts = [pm.sample(itenum, step, chain=i, progressbar=progress)
                  for i in range(chainnum)]
            trace = merge_traces(ts)
        if saveimage:
            pm.traceplot(trace).savefig("simple_linear_trace.png")
        print("Rhat = {0}".format(pm.gelman_rubin(trace)))
        t1 = time.clock()
        print("elapsed time = {0}".format(t1 - t0))
    if not multicore:
        # NOTE(review): replaces the merged trace with chain 0's trace
        # object for the plots and pickles below; whether that still holds
        # every chain depends on merge_traces merging into its first
        # argument — confirm the intent.
        trace = ts[0]
    with model:
        pm.traceplot(trace, model.vars)
        pm.forestplot(trace)
    # Pickle data is bytes: binary mode keeps dump/load correct on every
    # platform and is required on Python 3 (text mode raises TypeError).
    with open("simplelinearregression_model.pkl", "wb") as fpw:
        pkl.dump(model, fpw)
    with open("simplelinearregression_trace.pkl", "wb") as fpw:
        pkl.dump(trace, fpw)
    with open("simplelinearregression_model.pkl", "rb") as fp:
        model = pkl.load(fp)
    with open("simplelinearregression_trace.pkl", "rb") as fp:
        trace = pkl.load(fp)
def model_inference(model, niter=2000, nadvi=200000, ntraceadvi=1000,
                    seed=123, nchains=2):
    """Fit ``model`` with ADVI, then run NUTS chains scaled by the ADVI fit.

    Parameters
    ----------
    model : pm.Model
        Model to fit.
    niter : int
        Draws requested from ``pm.sample`` per chain. The first half is
        discarded and the remainder is thinned by 2.
    nadvi : int
        Number of ADVI iterations.
    ntraceadvi : int
        Number of draws taken from the fitted variational posterior.
    seed : int or None
        Random seed for ADVI and the variational draws. NUTS chain ``c``
        uses ``seed + c`` so the chains do not share a random stream.
    nchains : int
        Number of NUTS chains.

    Returns
    -------
    tuple
        ``(trace, v_params, tracevi)``: the merged NUTS trace, the ADVI
        parameters and the variational-posterior sample.
    """
    with model:
        v_params = pm.variational.advi(n=nadvi, random_seed=seed)
        tracevi = pm.variational.sample_vp(v_params, draws=ntraceadvi,
                                          random_seed=seed)
        traces = []
        for chain in range(nchains):
            # Squared ADVI stds are passed as a covariance (is_cov=True).
            step = pm.NUTS(scaling=np.power(model.dict_to_array(v_params.stds), 2),
                           is_cov=True, target_accept=0.95)
            # A shared seed would give every chain the same random stream
            # (and, from the same start point, the same samples), defeating
            # multi-chain sampling; offset it by the chain index instead.
            chain_seed = None if seed is None else seed + chain
            trace = pm.sample(niter, chain=chain, step=step,
                              random_seed=chain_seed)
            # Drop the first half as burn-in and keep every other draw.
            traces.append(trace[niter // 2::2])
    trace = merge_traces(traces)
    return trace, v_params, tracevi
def test_merge_traces_nonunique(self):
    """Merging the single-chain MultiTraces of strace0 and strace1 must fail.

    NOTE(review): relies on the fixtures ``self.strace0`` and
    ``self.strace1`` sharing a chain id, as the test name suggests —
    confirm against the class setup.
    """
    first, second = (base.MultiTrace([strace])
                     for strace in (self.strace0, self.strace1))
    with pytest.raises(ValueError):
        base.merge_traces([first, second])
def _mp_sample(njobs, args):
    """Run ``argsample`` over ``args`` in a pool of ``njobs`` processes.

    Parameters
    ----------
    njobs : int
        Number of worker processes.
    args : iterable
        One item per ``argsample`` call. NOTE(review): presumably one
        argument bundle per chain — confirm against ``argsample``.

    Returns
    -------
    The result of ``merge_traces`` applied to the per-item results.
    """
    pool = mp.Pool(njobs)
    try:
        traces = pool.map(argsample, args)
        pool.close()
    except BaseException:
        # Don't leave worker processes behind when sampling fails.
        pool.terminate()
        raise
    finally:
        # Wait for the workers to exit whether we closed or terminated.
        pool.join()
    return merge_traces(traces)
def test_merge_traces_nonunique(self):
    """merge_traces must raise ValueError for the strace0/strace1 pair.

    Each fixture trace is wrapped in its own single-chain MultiTrace.
    NOTE(review): assumes ``self.strace0`` and ``self.strace1`` share a
    chain id — confirm against the class setup.
    """
    to_merge = [
        base.MultiTrace([self.strace0]),
        base.MultiTrace([self.strace1]),
    ]
    with pytest.raises(ValueError):
        base.merge_traces(to_merge)
def test_merge_traces_no_traces(self):
    """merge_traces must raise ValueError when given an empty list."""
    nothing_to_merge = []
    with pytest.raises(ValueError):
        base.merge_traces(nothing_to_merge)
def merge_truncated_traces(traces):
    """Merge traces after cutting each one down to the shortest length.

    Every trace keeps only its *last* ``min_chain_length`` draws, where
    ``min_chain_length`` is the length of the shortest trace. All inputs
    to ``merge_traces`` therefore have the same length.

    Parameters
    ----------
    traces : iterable
        Traces supporting ``len()`` and slicing (e.g. MultiTrace objects).
        Any iterable is accepted, including a generator.

    Returns
    -------
    The result of ``merge_traces`` on the truncated traces.

    Raises
    ------
    ValueError
        If ``traces`` is empty.
    """
    # Materialise once: the input is traversed twice below.
    traces = list(traces)
    if not traces:
        raise ValueError("Cannot merge an empty collection of traces.")
    min_chain_length = min(len(trace) for trace in traces)
    # Use an explicit non-negative start: ``trace[-0:]`` is the *whole*
    # trace, so a zero-length chain would otherwise leave the others
    # untruncated.
    truncated_traces = [trace[len(trace) - min_chain_length:]
                        for trace in traces]
    return merge_traces(truncated_traces)