def test_iterator():
    """Drain a small ParallelSampler driven by a NUTS + Metropolis compound step.

    The test succeeds as long as iterating the sampler to exhaustion inside its
    context manager raises nothing; the individual draws are not inspected.
    """
    with pm.Model() as model:
        a = pm.Normal("a", shape=1)
        b = pm.HalfNormal("b")
        # NUTS drives the value variable of ``a``; Metropolis drives that of ``b``.
        nuts_step = pm.NUTS([model.rvs_to_values[a]])
        metropolis_step = pm.Metropolis([model.rvs_to_values[b]])

    compound = pm.CompoundStep([nuts_step, metropolis_step])

    initial_point = {
        "a": floatX(np.array([1.0])),
        "b_log__": floatX(np.array(2.0)),
    }
    # The same start-point dict object is shared by all three chains.
    chain_starts = [initial_point] * 3

    parallel = ps.ParallelSampler(
        10, 10, 3, 2, [2, 3, 4], chain_starts, compound, 0, False
    )
    with parallel:
        for _ in parallel:
            pass
def test_abort(mp_start_method):
    """Run a ProcessAdapter chain to its final draw, then optionally abort it.

    For each of ``abort=False`` and ``abort=True`` a fresh ``ProcessAdapter``
    is started, draws are requested until ``recv_draw`` reports the last one,
    ``abort()`` is called when requested, and the process is joined.

    Parameters
    ----------
    mp_start_method : str
        Multiprocessing start method name (e.g. ``"spawn"`` or ``"fork"``),
        passed to ``multiprocessing.get_context``.
    """
    # On Windows we cannot fork. Report the combination as skipped rather than
    # returning early, which pytest would count as a pass. Checking first also
    # avoids building the model for a configuration that is never exercised.
    if platform.system() == "Windows" and mp_start_method == "fork":
        import pytest

        pytest.skip("the 'fork' start method is not available on Windows")

    with pm.Model() as model:
        a = pm.Normal("a", shape=1)
        b = pm.HalfNormal("b")
        step1 = pm.NUTS([model.rvs_to_values[a]])
        step2 = pm.Metropolis([model.rvs_to_values[b]])

    step = pm.CompoundStep([step1, step2])

    # Only the "spawn" path hands ProcessAdapter a cloudpickled step method;
    # every other start method passes None.
    step_method_pickled = (
        cloudpickle.dumps(step, protocol=-1) if mp_start_method == "spawn" else None
    )
    # The context depends only on the start method, so build it once.
    ctx = multiprocessing.get_context(mp_start_method)

    for abort in (False, True):
        proc = ps.ProcessAdapter(
            10,
            10,
            step,
            chain=3,
            seed=1,
            mp_ctx=ctx,
            start={
                "a": floatX(np.array([1.0])),
                "b_log__": floatX(np.array(2.0)),
            },
            step_method_pickled=step_method_pickled,
        )
        proc.start()
        while True:
            proc.write_next()
            out = ps.ProcessAdapter.recv_draw([proc])
            # NOTE(review): out[1] is used as the "this was the last draw" flag;
            # confirm against ProcessAdapter.recv_draw's return tuple.
            if out[1]:
                break
        if abort:
            proc.abort()
        proc.join()