Esempio n. 1
0
def test_redis_pw_protection():
    """Sample with a Redis server-starter sampler configured with a password.

    Each simulated particle is accepted or rejected at random
    (``np.random.randint(2)`` yields 0 or 1 with equal probability). The test
    checks that ``sample_until_n_accepted(10, ...)`` returns exactly 10
    accepted particles. ``sampler.cleanup()`` is called in a ``finally``
    clause, so it also runs when sampling or the assertion fails.
    """
    sampler = RedisEvalParallelSamplerServerStarter(  # noqa: S106
        password="******", port=8888)

    def simulate_one():
        # Randomly mark the particle as accepted (1) or rejected (0).
        accepted = np.random.randint(2)
        return Particle(0, {}, 0.1, [], [], accepted)

    try:
        sample = sampler.sample_until_n_accepted(10, simulate_one)
        assert 10 == len(sample.get_accepted_population())
    finally:
        # Previously cleanup was skipped on any failure. That presumably left
        # the started server running and port 8888 occupied for later tests.
        sampler.cleanup()
Esempio n. 2
0
def test_redis_multiprocess():
    """Sample with a server-starter sampler using one worker and one process.

    The sampler is built with ``batch_size=3, workers=1,
    processes_per_worker=1``. Each simulated particle is accepted or rejected
    at random (``np.random.randint(2)``). The test checks that exactly 10
    accepted particles are returned. ``sampler.cleanup()`` is called in a
    ``finally`` clause, so it also runs when sampling or the assertion fails.
    """
    sampler = RedisEvalParallelSamplerServerStarter(
        batch_size=3, workers=1, processes_per_worker=1)

    def simulate_one():
        # Randomly mark the particle as accepted (1) or rejected (0).
        accepted = np.random.randint(2)
        return Particle(0, {}, 0.1, [], [], accepted)

    try:
        sample = sampler.sample_until_n_accepted(10, simulate_one)
        assert 10 == len(sample.get_accepted_population())
    finally:
        # Always release what the sampler started, even if the test fails.
        sampler.cleanup()
Esempio n. 3
0
def test_redis_catch_error():
    """Run a short ABC-SMC analysis whose model randomly raises ``ValueError``.

    The model raises in roughly 10% of evaluations
    (``np.random.uniform() < 0.1``). The test passes only if ``abc.run``
    completes (``minimum_epsilon=.1``, at most 3 populations) without an
    exception escaping. Everything after the sampler is created runs inside
    ``try``, so ``sampler.cleanup()`` in ``finally`` also runs when
    construction, ``abc.new`` or ``abc.run`` raises.
    """

    def model(pars):
        # Deliberately fail about 10% of the time.
        if np.random.uniform() < 0.1:
            raise ValueError("error")
        return {'s0': pars['p0'] + 0.2 * np.random.uniform()}

    def distance(s0, s1):
        # Absolute difference of the single summary statistic 's0'.
        return abs(s0['s0'] - s1['s0'])

    prior = Distribution(p0=RV("uniform", 0, 10))
    sampler = RedisEvalParallelSamplerServerStarter(
        batch_size=3, workers=1, processes_per_worker=1, port=8775)

    try:
        abc = ABCSMC(model, prior, distance, sampler=sampler,
                     population_size=10)

        # NOTE(review): a fixed file name in the shared temp dir may be
        # reused across runs or concurrent test sessions.
        db_file = "sqlite:///" + os.path.join(tempfile.gettempdir(),
                                              "test.db")
        data = {'s0': 2.8}
        abc.new(db_file, data)

        abc.run(minimum_epsilon=.1, max_nr_populations=3)
    finally:
        # Previously any exception above skipped cleanup. That presumably
        # left the started server running on port 8775.
        sampler.cleanup()
Esempio n. 4
0
def redis_starter_sampler():
    """Yield a server-starter sampler with ``batch_size=5``, then clean it up.

    This is a setup/teardown generator. The sampler is created before the
    ``yield``, and ``cleanup()`` is called once the generator is resumed
    after the ``yield``. It is presumably registered as a pytest fixture
    elsewhere; the decorator is not visible here.
    """
    starter = RedisEvalParallelSamplerServerStarter(batch_size=5)
    yield starter
    # Teardown: runs when the consumer advances the generator past the yield.
    starter.cleanup()