def test_iterator():
    """Smoke-test iteration over ``ps.ParallelSampler``.

    A model with one Normal and one HalfNormal variable is sampled with a
    compound NUTS + Metropolis step.  The sampler is entered as a context
    manager and iterated until exhausted; there are no assertions, so the
    test passes as long as nothing raises.
    """
    with pm.Model() as mdl:
        normal_rv = pm.Normal("a", shape=1)
        half_normal_rv = pm.HalfNormal("b")
        # Both step methods are created while the model context is active,
        # each one on the value variable of a single random variable.
        sub_steps = [
            pm.NUTS([mdl.rvs_to_values[normal_rv]]),
            pm.Metropolis([mdl.rvs_to_values[half_normal_rv]]),
        ]

    step = pm.CompoundStep(sub_steps)

    # NOTE(review): "b_log__" is presumably the name of the log-transformed
    # value variable of ``b`` -- confirm against the PyMC version in use.
    start = dict(
        a=floatX(np.array([1.0])),
        b_log__=floatX(np.array(2.0)),
    )
    # The same start dict object is handed over three times.
    sampler = ps.ParallelSampler(
        10, 10, 3, 2, [2, 3, 4], [start, start, start], step, 0, False
    )
    with sampler:
        for _ in sampler:
            pass
# Example #2
# 0
def test_abort(mp_start_method):
    """Drive a single ``ps.ProcessAdapter`` worker and shut it down.

    The worker is run twice: once it is simply joined after the receive loop
    ends, and once ``abort()`` is called before joining.  There are no
    assertions; the test passes as long as nothing raises.

    Parameters
    ----------
    mp_start_method
        Start method name passed to ``multiprocessing.get_context``.  The
        test returns early for ``"fork"`` on Windows, and pre-pickles the
        step method with cloudpickle for ``"spawn"``.
    """
    with pm.Model() as mdl:
        normal_rv = pm.Normal("a", shape=1)
        half_normal_rv = pm.HalfNormal("b")
        # Step methods are created while the model context is active.
        sub_steps = [
            pm.NUTS([mdl.rvs_to_values[normal_rv]]),
            pm.Metropolis([mdl.rvs_to_values[half_normal_rv]]),
        ]

    step = pm.CompoundStep(sub_steps)

    # on Windows we cannot fork
    if mp_start_method == "fork" and platform.system() == "Windows":
        return

    # Only the "spawn" start method gets a pre-pickled step method.
    step_method_pickled = (
        cloudpickle.dumps(step, protocol=-1) if mp_start_method == "spawn" else None
    )

    for abort in (False, True):
        ctx = multiprocessing.get_context(mp_start_method)
        # NOTE(review): "b_log__" is presumably the name of the
        # log-transformed value variable of ``b`` -- confirm.
        start_point = dict(
            a=floatX(np.array([1.0])),
            b_log__=floatX(np.array(2.0)),
        )
        proc = ps.ProcessAdapter(
            10,
            10,
            step,
            chain=3,
            seed=1,
            mp_ctx=ctx,
            start=start_point,
            step_method_pickled=step_method_pickled,
        )
        proc.start()

        # Keep requesting and receiving draws until the second element of
        # the received tuple is truthy (presumably the "last draw" flag --
        # verify against ``recv_draw``).
        finished = False
        while not finished:
            proc.write_next()
            finished = ps.ProcessAdapter.recv_draw([proc])[1]

        if abort:
            proc.abort()
        proc.join()